refactor sampler/processor and a few improvements

Awni Hannun 2024-11-05 17:01:21 -08:00
parent 3783156072
commit 0be87b3c53
9 changed files with 153 additions and 164 deletions

View File

@@ -101,7 +101,8 @@ To see a description of all the arguments you can do:

 #### Streaming

 For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text. For example,
+generator object which streams the output text, token, and log probabilities.
+For example,

 ```python
 from mlx_lm import load, stream_generate
@@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )

-for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
-    print(t, end="", flush=True)
+for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(text, end="", flush=True)
 print()
 ```
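
The new tuple yield also exposes per-token detail. A minimal sketch, reusing `model`, `tokenizer`, and `prompt` from the example above (it assumes the yielded token id can index into the log-probability vector):

```python
for text, token, logprobs in stream_generate(model, tokenizer, prompt, max_tokens=512):
    # `token` is the sampled token id and `logprobs` holds the
    # log-probabilities over the vocabulary for this step.
    print(f"{text} [logp={logprobs[token].item():.2f}]", end="", flush=True)
print()
```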

View File

@@ -152,6 +152,7 @@ def main():
         model(y[:step_size][None], cache=cache)
         mx.eval([c.state for c in cache])
+        mx.metal.clear_cache()
         processed += min(y.size, step_size)
         y = y[step_size:]
         current = time.time()
@@ -165,14 +166,13 @@ def main():
     )
     print()
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+    print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
     save_prompt_cache(args.prompt_cache_file, cache, metadata)
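
A side note on the denominator change above: `2**30` divides into binary gibibytes while `1e9` divides into decimal gigabytes, so the same peak prints as a slightly larger number after this change:

```python
peak_bytes = 8 * 2**30      # example: an 8 GiB peak
print(peak_bytes / 2**30)   # 8.0 (GiB, binary)
print(peak_bytes / 1e9)     # 8.589934592 (GB, decimal)
```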

View File

@@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
+DEFAULT_MIN_P = 0.0
+DEFAULT_MIN_TOKENS_TO_KEEP = 1
 DEFAULT_SEED = 0
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 DEFAULT_QUANTIZED_KV_START = 5000
@@ -52,6 +54,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--prompt",
+        "-p",
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
@@ -68,6 +71,15 @@ def setup_arg_parser():
     parser.add_argument(
         "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
     )
+    parser.add_argument(
+        "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
+    )
+    parser.add_argument(
+        "--min-tokens-to-keep",
+        type=int,
+        default=DEFAULT_MIN_TOKENS_TO_KEEP,
+        help="Minimum tokens to keep for min-p sampling.",
+    )
     parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
     parser.add_argument(
         "--ignore-chat-template",
@@ -238,8 +250,6 @@ def main():
         raise ValueError("Cannot use --colorize with --verbose=False")
     formatter = colorprint_by_t0 if args.colorize else None

-    sampler = make_sampler(
-        args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
     response = generate(
         model,
         tokenizer,
@@ -247,7 +257,10 @@ def main():
         args.max_tokens,
         verbose=args.verbose,
         formatter=formatter,
-        sampler=sampler,
+        temp=args.temp,
+        top_p=args.top_p,
+        min_p=args.min_p,
+        min_tokens_to_keep=args.min_tokens_to_keep,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
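
A sketch of the equivalent programmatic call, assuming `generate` forwards these keyword arguments to `generate_step` the way the diff above passes them:

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
text = generate(
    model,
    tokenizer,
    "hello",
    max_tokens=100,
    temp=0.7,
    min_p=0.05,  # drop tokens below 5% of the top token's probability
    min_tokens_to_keep=1,
    verbose=False,
)
print(text)
```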

View File

@@ -1,6 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.

 from functools import partial
+from typing import Callable, Dict, Optional

 import mlx.core as mx
@@ -25,7 +26,7 @@ def make_sampler(
           be filtered by min_p sampling.

     Returns:
-        Callabel[mx.array, mx.array]:
+        Callable[mx.array, mx.array]:
             A sampler which takes log-probabilities and returns tokens.
     """
     if temp == 0:
@@ -38,7 +39,11 @@ def make_sampler(
         return lambda x: categorical_sampling(x, temp)


-def make_logits_processors():
+def make_logits_processors(
+    logit_bias: Optional[Dict[int, float]] = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
+):
     """
     Make logits processors for use with ``generate_step``.

@@ -48,8 +53,13 @@ def make_logits_processors():
         repetition_context_size (int, optional): The number of tokens to
           consider for repetition penalty. Default: ``20``.
         logit_bias (dictionary, optional): Additive logit bias.
-    """
+
+    Returns:
+        List[Callable[[mx.array, mx.array], mx.array]]:
+            A list of logits processors. Each processor in the list is a
+            callable which takes an array of tokens and an array of logits
+            and returns the updated logits.
+    """
     logits_processors = []
     if logit_bias:
         indices = mx.array(list(logit_bias.keys()))
@@ -61,6 +71,12 @@ def make_logits_processors():
         logits_processors.append(logit_bias_processor)

+    if repetition_penalty and repetition_penalty != 0.0:
+        logits_processors.append(
+            make_repetition_penalty(repetition_penalty, repetition_context_size)
+        )
+
+    return logits_processors


 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def min_p_sampling(
@@ -159,7 +175,7 @@ def categorical_sampling(logits, temp):
     return mx.random.categorical(logits * (1 / temp))


-def repetition_penalty(penalty: float, context_size: int = 20):
+def make_repetition_penalty(penalty: float, context_size: int = 20):
     """
     Make repetition penalty processor.

@@ -177,7 +193,7 @@ def repetition_penalty(penalty: float, context_size: int = 20):
     if penalty < 0 or not isinstance(penalty, float):
         raise ValueError(f"penalty must be a non-negative float, got {penalty}")

-    def repetition_penalty_processor(logits, tokens):
+    def repetition_penalty_processor(tokens, logits):
         if len(tokens) > 0:
             tokens = tokens[-context_size:]
             selected_logits = logits[:, tokens]
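
A small usage sketch of the two factories, with the positional order for `make_sampler` taken from the call sites in this diff (the bias values and shapes are illustrative):

```python
import mlx.core as mx
from mlx_lm.sample_utils import make_logits_processors, make_sampler

# temp, top_p, min_p, min_tokens_to_keep
sampler = make_sampler(0.7, 1.0, 0.0, 1)

# Each processor maps (generated_tokens, logits) -> updated logits.
processors = make_logits_processors(
    logit_bias={0: -100.0},  # hypothetical: discourage token id 0
    repetition_penalty=1.2,
    repetition_context_size=20,
)

logits = mx.random.normal((1, 32))  # stand-in for a (1, vocab) model output
tokens = mx.array([1, 5, 5])        # previously generated token ids
for processor in processors:
    logits = processor(tokens, logits)

logprobs = logits - mx.logsumexp(logits, keepdims=True)
print(sampler(logprobs))  # a sampled token id
```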

View File

@@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir
 from ._version import __version__
 from .models.cache import make_prompt_cache
-from .utils import generate_step, load
+from .utils import load, stream_generate


 def get_system_fingerprint():
@@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):
         # Call endpoint specific method
         prompt = endpoints[self.path]()
-
-        # Call method based on response type
-        method = self.handle_stream if self.stream else self.handle_completion
-        method(prompt, stop_id_sequences)
+        self.handle_completion(prompt, stop_id_sequences)

     def validate_model_parameters(self):
""" """
@@ -452,25 +449,28 @@
             stop_id_sequences (List[List[int]]): A list of stop words passed
                 to the stopping_criteria function
         """
-        detokenizer = self.tokenizer.detokenizer
-        detokenizer.reset()
         tokens = []
         finish_reason = "length"
         stop_sequence_suffix = None
-        logging.debug(f"Starting completion:")
+        if self.stream:
+            self.end_headers()
+            logging.debug(f"Starting stream:")
+        else:
+            logging.debug(f"Starting completion:")
         token_logprobs = []
         top_tokens = []
-        prompt = mx.array(self.get_prompt_cache(prompt))
+        prompt = self.get_prompt_cache(prompt)
+        text = ""
         tic = time.perf_counter()
-        for _, (token, logprobs) in zip(
-            range(self.max_tokens),
-            generate_step(
-                prompt=prompt,
+        for n, (segment, token, logprobs) in enumerate(
+            stream_generate(
                 model=self.model,
+                tokenizer=self.tokenizer,
+                prompt=prompt,
+                max_tokens=self.max_tokens,
                 temp=self.temperature,
-                top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
                 logit_bias=self.logit_bias,
@@ -481,8 +481,8 @@
             prompt_time = time.perf_counter() - tic
             tic = time.perf_counter()
-            detokenizer.add_token(token)
-            logging.debug(detokenizer.text)
+            text += segment
+            logging.debug(text)
             tokens.append(token)

             if self.logprobs > 0:
@@ -503,128 +503,63 @@
                     stop_sequence_suffix = self.tokenizer.decode(
                         tokens[-stop_condition.trim_length :]
                     )
+                    text = text[: -len(stop_sequence_suffix)]
                 break

+            if self.stream:
+                # If the end of tokens overlaps with a stop sequence, generate new
+                # tokens until we know if the stop sequence is hit or not
+                if any(
+                    (
+                        sequence_overlap(tokens, sequence)
+                        for sequence in stop_id_sequences
+                    )
+                ):
+                    continue
+                elif segment:
+                    response = self.generate_response(segment, None)
+                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                    self.wfile.flush()
+
         self.prompt_cache.tokens.extend(tokens)
-        detokenizer.finalize()
-        text = (
-            detokenizer.text
-            if stop_sequence_suffix is None
-            else detokenizer.text[: -len(stop_sequence_suffix)]
-        )
         gen_time = time.perf_counter() - tic
         prompt_tps = len(prompt) / prompt_time
         gen_tps = len(tokens) / gen_time
         peak_mem = mx.metal.get_peak_memory() / 1e9
-        response = self.generate_response(
-            text,
-            finish_reason,
-            len(prompt),
-            len(tokens),
-            token_logprobs=token_logprobs,
-            top_tokens=top_tokens,
-            tokens=tokens,
-        )
         logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
         logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
         logging.debug(f"Peak memory: {peak_mem:.3f} GB")
-        response_json = json.dumps(response).encode()
-        indent = "\t"  # Backslashes can't be inside of f-strings
-        logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
-        # Send an additional Content-Length header when it is known
-        self.send_header("Content-Length", str(len(response_json)))
-        self.end_headers()
-        self.wfile.write(response_json)
-        self.wfile.flush()
-
-    def handle_stream(
-        self,
-        prompt: List[int],
-        stop_id_sequences: List[List[int]],
-    ):
-        """
-        Generate response to prompt and forward it to the client using a Server
-        Sent Events (SSE) stream.
-
-        Args:
-            prompt (mx.array): The tokenized prompt
-            stop_id_sequences (List[List[int]]): A list of stop words passed to
-                the stopping_criteria function
-        """
-        # No additional headers are needed, call end_headers
-        self.end_headers()
-        detokenizer = self.tokenizer.detokenizer
-        detokenizer.reset()
-        tokens = []
-        stop_sequence_suffix = None
-        logging.debug(f"Starting stream:")
-        prompt = mx.array(self.get_prompt_cache(prompt))
-        for _, (token, _) in zip(
-            range(self.max_tokens),
-            generate_step(
-                prompt=prompt,
-                model=self.model,
-                temp=self.temperature,
-                top_p=self.top_p,
-                repetition_penalty=self.repetition_penalty,
-                repetition_context_size=self.repetition_context_size,
-                prompt_cache=self.prompt_cache.cache,
-            ),
-        ):
-            detokenizer.add_token(token)
-            logging.debug(detokenizer.text)
-            tokens.append(token)
-            stop_condition = stopping_criteria(
-                tokens,
-                stop_id_sequences,
-                self.tokenizer.eos_token_id,
-            )
-            if stop_condition.stop_met:
-                if stop_condition.trim_length:
-                    stop_sequence_suffix = self.tokenizer.decode(
-                        tokens[-stop_condition.trim_length :]
-                    )
-                break
-            # If the end of tokens overlaps with a stop sequence, generate new
-            # tokens until we know if the stop sequence is hit or not
-            if any(
-                (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
-            ):
-                continue
-            new_text = detokenizer.last_segment
-            if new_text:
-                response = self.generate_response(new_text, None)
-                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                self.wfile.flush()
-        self.prompt_cache.tokens.extend(tokens)
-        # check if there is any remaining text to send
-        detokenizer.finalize()
-        last_segment = detokenizer.last_segment
-        if last_segment:
-            if stop_sequence_suffix is not None:
-                last_segment = last_segment[: -len(stop_sequence_suffix)]
-            response = self.generate_response(last_segment, "length")
-            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-            self.wfile.flush()
-        if self.stream_options is not None and self.stream_options["include_usage"]:
-            response = self.completion_usage_response(len(prompt), len(tokens))
-            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-        self.wfile.write("data: [DONE]\n\n".encode())
-        self.wfile.flush()
+        if self.stream:
+            response = self.generate_response(segment, finish_reason)
+            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+            self.wfile.flush()
+            if self.stream_options is not None and self.stream_options["include_usage"]:
+                response = self.completion_usage_response(len(prompt), len(tokens))
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+            self.wfile.write("data: [DONE]\n\n".encode())
+            self.wfile.flush()
+        else:
+            response = self.generate_response(
+                text,
+                finish_reason,
+                len(prompt),
+                len(tokens),
+                token_logprobs=token_logprobs,
+                top_tokens=top_tokens,
+                tokens=tokens,
+            )
+            response_json = json.dumps(response).encode()
+            indent = "\t"  # Backslashes can't be inside of f-strings
+            logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
+            # Send an additional Content-Length header when it is known
+            self.send_header("Content-Length", str(len(response_json)))
+            self.end_headers()
+            self.wfile.write(response_json)
+            self.wfile.flush()

     def completion_usage_response(
         self,
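
For context, a minimal client sketch against a running `mlx_lm.server`; the port and the OpenAI-style path are assumptions, they are not shown in this hunk:

```python
import json
from urllib import request

payload = {
    "messages": [{"role": "user", "content": "hello"}],
    "max_tokens": 64,
    "stream": False,  # True exercises the SSE branch above instead
}
req = request.Request(
    "http://localhost:8080/v1/chat/completions",  # assumed host/port/path
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    body = json.load(resp)
# assumed OpenAI-style response shape
print(body["choices"][0]["message"]["content"])
```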

View File

@@ -285,7 +285,7 @@ def train(
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
-            peak_mem = mx.metal.get_peak_memory() / 2**30
+            peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
                 print(
                     f"Iter {it}: Train loss {train_loss:.3f}, "

View File

@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
 # Local imports
 from .models import cache
-from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
+from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters
@@ -35,7 +35,8 @@ MODEL_REMAPPING = {
 MAX_FILE_SIZE_GB = 5

 # A stream on the default device just for generation
-generation_stream = mx.new_stream(mx.default_device())
+# generation_stream = mx.new_stream(mx.default_device())
+generation_stream = mx.default_stream(mx.default_device())


 class ModelNotFoundError(Exception):
@@ -155,10 +156,16 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):

 def generate_step(
     prompt: mx.array,
     model: nn.Module,
+    temp: float = 0.0,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
+    top_p: float = 1.0,
+    min_p: float = 0.0,
+    min_tokens_to_keep: int = 1,
     prefill_step_size: int = 512,
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
-    sampler: Optional[Callable[mx.array, mx.array]] = None,
+    logit_bias: Optional[Dict[int, float]] = None,
     logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
@@ -170,14 +177,24 @@ def generate_step(
     Args:
         prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
+        temp (float): The temperature for sampling, if 0 the argmax is used.
+          Default: ``0``.
+        repetition_penalty (float, optional): The penalty factor for repeating
+          tokens.
+        repetition_context_size (int, optional): The number of tokens to
+          consider for repetition penalty. Default: ``20``.
+        top_p (float, optional): Nucleus sampling; higher values let the model
+          consider less likely words.
+        min_p (float, optional): The minimum value (scaled by the top token's
+          probability) that a token probability must have to be considered.
+        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+          be filtered by min_p sampling.
         prefill_step_size (int): Step size for processing the prompt.
         max_kv_size (int, optional): Maximum size of the key-value cache. Old
           entries (except the first 4 tokens) will be overwritten.
         prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
           provided, the cache will be updated in place.
-        sampler (Callable[mx.array, mx.array], optional): A function which
-          takes log probabilities and returns tokens. If ``None`` then the
-          argmax is used. Default: ``None``.
+        logit_bias (dictionary, optional): Additive logit bias.
         logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
           A list of functions that take tokens and logits and return the processed
           logits. Default: ``None``.
@@ -204,7 +221,11 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")

-    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
+    logits_processors = logits_processors or []
+    logits_processors.extend(
+        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
+    )

     def _step(y):
         with mx.stream(generation_stream):
@@ -222,7 +243,7 @@ def generate_step(
                 prompt_cache, quantized_kv_start, kv_group_size, kv_bits
             )
-            logprobs = logits - mx.logsumexp(logits)
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
             y = sampler(logprobs)
             return y, logprobs.squeeze(0)
@@ -249,7 +270,7 @@ def generate_step(

 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     max_tokens: int = 100,
     **kwargs,
 ) -> Union[str, Generator[str, None, None]]:
@@ -257,7 +278,7 @@ def stream_generate(
     A generator producing text based on the given prompt from the model.

     Args:
-        prompt (mx.array): The input prompt.
+        prompt (Union[str, List[int]]): The input prompt.
         model (nn.Module): The model to use for generation.
         max_tokens (int): The maximum number of tokens to generate.
         kwargs: The remaining options get passed to :func:`generate_step`.
@@ -269,23 +290,26 @@ def stream_generate(
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)

-    prompt_tokens = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(
+        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
+    )
     detokenizer = tokenizer.detokenizer

-    detokenizer.reset()
-    for n, (token, _) in zip(
-        range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
-    ):
-        if token == tokenizer.eos_token_id:
-            break
-        detokenizer.add_token(token)
-
-        # Yield the last segment if streaming
-        yield detokenizer.last_segment
-
-    detokenizer.finalize()
-    yield detokenizer.last_segment
+    with wired_limit(model, [generation_stream]):
+        detokenizer.reset()
+        for n, (token, logits) in zip(
+            range(max_tokens),
+            generate_step(prompt_tokens, model, **kwargs),
+        ):
+            if token == tokenizer.eos_token_id:
+                break
+            detokenizer.add_token(token)
+
+            # Yield the last segment if streaming
+            yield detokenizer.last_segment, token, logits
+
+        detokenizer.finalize()
+        yield detokenizer.last_segment, token, logits
 def generate(
@@ -322,7 +346,7 @@ def generate(
     prompt_tokens = mx.array(tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer

-    with wired_limit(model):
+    with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
         detokenizer.reset()
         for n, (token, logprobs) in zip(
@@ -361,7 +385,7 @@ def generate(
                 f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
             )
             print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-            peak_mem = mx.metal.get_peak_memory() / 2**30
+            peak_mem = mx.metal.get_peak_memory() / 1e9
             print(f"Peak memory: {peak_mem:.3f} GB")

     return detokenizer.text
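
Since `prompt` may now be a list of token ids, a caller that applies its own chat template can skip re-encoding. A minimal sketch:

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")
messages = [{"role": "user", "content": "hello"}]

# apply_chat_template with tokenize=True (the default) returns token ids
token_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

for text, *_ in stream_generate(model, tokenizer, token_ids, max_tokens=128):
    print(text, end="", flush=True)
print()
```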

View File

@@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
             "hello",
             max_tokens=5,
             verbose=False,
-            logits_processor=[logits_processor],
+            logits_processors=[logits_processor],
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)
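
For reference, a processor compatible with `logits_processors` takes `(tokens, logits)` and returns updated logits, matching the docstring added to `make_logits_processors`. A hypothetical example that suppresses a single token id:

```python
import mlx.core as mx

BANNED_ID = 0  # hypothetical token id to suppress

def ban_token(tokens: mx.array, logits: mx.array) -> mx.array:
    # logits has shape (1, vocab_size); mask the banned id with a large
    # negative value so it is effectively never sampled
    mask = mx.arange(logits.shape[-1]) == BANNED_ID
    return mx.where(mask, -1e9, logits)

# e.g. generate(model, tokenizer, "hello", logits_processors=[ban_token])
```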

View File

@@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase):
         ):
             i += 1
             self.assertEqual(tok, toks[i])
-            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
+            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))

 if __name__ == "__main__":