mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
refactor sampler/processor and a few improvements
This commit is contained in:
parent
3783156072
commit
0be87b3c53
@ -101,7 +101,8 @@ To see a description of all the arguments you can do:
|
|||||||
#### Streaming
|
#### Streaming
|
||||||
|
|
||||||
For streaming generation, use the `stream_generate` function. This returns a
|
For streaming generation, use the `stream_generate` function. This returns a
|
||||||
generator object which streams the output text. For example,
|
generator object which streams the output text, token, and log probabilities.
|
||||||
|
For example,
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mlx_lm import load, stream_generate
|
from mlx_lm import load, stream_generate
|
||||||
@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
|
|||||||
messages, tokenize=False, add_generation_prompt=True
|
messages, tokenize=False, add_generation_prompt=True
|
||||||
)
|
)
|
||||||
|
|
||||||
for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
|
||||||
print(t, end="", flush=True)
|
print(t, end="", flush=True)
|
||||||
print()
|
print()
|
||||||
```
|
```
|
||||||
|
@ -152,6 +152,7 @@ def main():
|
|||||||
|
|
||||||
model(y[:step_size][None], cache=cache)
|
model(y[:step_size][None], cache=cache)
|
||||||
mx.eval([c.state for c in cache])
|
mx.eval([c.state for c in cache])
|
||||||
|
mx.metal.clear_cache()
|
||||||
processed += min(y.size, step_size)
|
processed += min(y.size, step_size)
|
||||||
y = y[step_size:]
|
y = y[step_size:]
|
||||||
current = time.time()
|
current = time.time()
|
||||||
@ -165,14 +166,13 @@ def main():
|
|||||||
)
|
)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
|
||||||
|
|
||||||
print("Saving...")
|
print("Saving...")
|
||||||
metadata = {}
|
metadata = {}
|
||||||
metadata["model"] = args.model
|
metadata["model"] = args.model
|
||||||
metadata["chat_template"] = tokenizer.chat_template
|
metadata["chat_template"] = tokenizer.chat_template
|
||||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
|
||||||
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello"
|
|||||||
DEFAULT_MAX_TOKENS = 100
|
DEFAULT_MAX_TOKENS = 100
|
||||||
DEFAULT_TEMP = 0.0
|
DEFAULT_TEMP = 0.0
|
||||||
DEFAULT_TOP_P = 1.0
|
DEFAULT_TOP_P = 1.0
|
||||||
|
DEFAULT_MIN_P = 0.0
|
||||||
|
DEFAULT_MIN_TOKENS_TO_KEEP = 1
|
||||||
DEFAULT_SEED = 0
|
DEFAULT_SEED = 0
|
||||||
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
|
||||||
DEFAULT_QUANTIZED_KV_START = 5000
|
DEFAULT_QUANTIZED_KV_START = 5000
|
||||||
@ -52,6 +54,7 @@ def setup_arg_parser():
|
|||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--prompt",
|
"--prompt",
|
||||||
|
"-p",
|
||||||
default=DEFAULT_PROMPT,
|
default=DEFAULT_PROMPT,
|
||||||
help="Message to be processed by the model ('-' reads from stdin)",
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
)
|
)
|
||||||
@ -68,6 +71,15 @@ def setup_arg_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
"--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--min-tokens-to-keep",
|
||||||
|
type=float,
|
||||||
|
default=DEFAULT_MIN_TOKENS_TO_KEEP,
|
||||||
|
help="Minimum tokens to keep for min-p sampling.",
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--ignore-chat-template",
|
"--ignore-chat-template",
|
||||||
@ -238,8 +250,6 @@ def main():
|
|||||||
raise ValueError("Cannot use --colorize with --verbose=False")
|
raise ValueError("Cannot use --colorize with --verbose=False")
|
||||||
formatter = colorprint_by_t0 if args.colorize else None
|
formatter = colorprint_by_t0 if args.colorize else None
|
||||||
|
|
||||||
sampler = make_sampler(
|
|
||||||
args.temp, args.top_p, args.min_p, args.min_tokens_to_keep)
|
|
||||||
response = generate(
|
response = generate(
|
||||||
model,
|
model,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@ -247,7 +257,10 @@ def main():
|
|||||||
args.max_tokens,
|
args.max_tokens,
|
||||||
verbose=args.verbose,
|
verbose=args.verbose,
|
||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
sampler=sampler,
|
temp=args.temp,
|
||||||
|
top_p=args.top_p,
|
||||||
|
min_p=args.min_p,
|
||||||
|
min_tokens_to_keep=args.min_tokens_to_keep,
|
||||||
max_kv_size=args.max_kv_size,
|
max_kv_size=args.max_kv_size,
|
||||||
prompt_cache=prompt_cache if using_cache else None,
|
prompt_cache=prompt_cache if using_cache else None,
|
||||||
kv_bits=args.kv_bits,
|
kv_bits=args.kv_bits,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from typing import Callable, Dict, Optional
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
@ -25,7 +26,7 @@ def make_sampler(
|
|||||||
be filtered by min_p sampling.
|
be filtered by min_p sampling.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Callabel[mx.array, mx.array]:
|
Callable[mx.array, mx.array]:
|
||||||
A sampler which takes log-probabilities and returns tokens.
|
A sampler which takes log-probabilities and returns tokens.
|
||||||
"""
|
"""
|
||||||
if temp == 0:
|
if temp == 0:
|
||||||
@ -38,7 +39,11 @@ def make_sampler(
|
|||||||
return lambda x: categorical_sampling(x, temp)
|
return lambda x: categorical_sampling(x, temp)
|
||||||
|
|
||||||
|
|
||||||
def make_logits_processors():
|
def make_logits_processors(
|
||||||
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
repetition_context_size: Optional[int] = 20,
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Make logits processors for use with ``generate_step``.
|
Make logits processors for use with ``generate_step``.
|
||||||
|
|
||||||
@ -48,8 +53,13 @@ def make_logits_processors():
|
|||||||
repetition_context_size (int, optional): The number of tokens to
|
repetition_context_size (int, optional): The number of tokens to
|
||||||
consider for repetition penalty. Default: ``20``.
|
consider for repetition penalty. Default: ``20``.
|
||||||
logit_bias (dictionary, optional): Additive logit bias.
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
"""
|
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Callable[[mx.array, mx.array], mx.array]]:
|
||||||
|
A list of logits processors. Each processor in the list is a
|
||||||
|
callable which takes an array of tokens and an array of logits
|
||||||
|
and returns the updated logits.
|
||||||
|
"""
|
||||||
logits_processors = []
|
logits_processors = []
|
||||||
if logit_bias:
|
if logit_bias:
|
||||||
indices = mx.array(list(logit_bias.keys()))
|
indices = mx.array(list(logit_bias.keys()))
|
||||||
@ -61,6 +71,12 @@ def make_logits_processors():
|
|||||||
|
|
||||||
logits_processors.append(logit_bias_processor)
|
logits_processors.append(logit_bias_processor)
|
||||||
|
|
||||||
|
if repetition_penalty and repetition_penalty != 0.0:
|
||||||
|
logits_processors.append(
|
||||||
|
make_repetition_penalty(repetition_penalty, repetition_context_size)
|
||||||
|
)
|
||||||
|
return logits_processors
|
||||||
|
|
||||||
|
|
||||||
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
@partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
|
||||||
def min_p_sampling(
|
def min_p_sampling(
|
||||||
@ -159,7 +175,7 @@ def categorical_sampling(logits, temp):
|
|||||||
return mx.random.categorical(logits * (1 / temp))
|
return mx.random.categorical(logits * (1 / temp))
|
||||||
|
|
||||||
|
|
||||||
def repetition_penalty(penalty: float, context_size: int = 20):
|
def make_repetition_penalty(penalty: float, context_size: int = 20):
|
||||||
"""
|
"""
|
||||||
Make repetition penalty processor.
|
Make repetition penalty processor.
|
||||||
|
|
||||||
@ -177,7 +193,7 @@ def repetition_penalty(penalty: float, context_size: int = 20):
|
|||||||
if penalty < 0 or not isinstance(penalty, float):
|
if penalty < 0 or not isinstance(penalty, float):
|
||||||
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
|
raise ValueError(f"penalty must be a non-negative float, got {penalty}")
|
||||||
|
|
||||||
def repetition_penalty_processor(logits, tokens):
|
def repetition_penalty_processor(tokens, logits):
|
||||||
if len(tokens) > 0:
|
if len(tokens) > 0:
|
||||||
tokens = tokens[-context_size:]
|
tokens = tokens[-context_size:]
|
||||||
selected_logits = logits[:, tokens]
|
selected_logits = logits[:, tokens]
|
||||||
|
@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir
|
|||||||
|
|
||||||
from ._version import __version__
|
from ._version import __version__
|
||||||
from .models.cache import make_prompt_cache
|
from .models.cache import make_prompt_cache
|
||||||
from .utils import generate_step, load
|
from .utils import load, stream_generate
|
||||||
|
|
||||||
|
|
||||||
def get_system_fingerprint():
|
def get_system_fingerprint():
|
||||||
@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
|
|
||||||
# Call endpoint specific method
|
# Call endpoint specific method
|
||||||
prompt = endpoints[self.path]()
|
prompt = endpoints[self.path]()
|
||||||
|
self.handle_completion(prompt, stop_id_sequences)
|
||||||
# Call method based on response type
|
|
||||||
method = self.handle_stream if self.stream else self.handle_completion
|
|
||||||
method(prompt, stop_id_sequences)
|
|
||||||
|
|
||||||
def validate_model_parameters(self):
|
def validate_model_parameters(self):
|
||||||
"""
|
"""
|
||||||
@ -452,25 +449,28 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
stop_id_sequences (List[List[int]]): A list of stop words passed
|
stop_id_sequences (List[List[int]]): A list of stop words passed
|
||||||
to the stopping_criteria function
|
to the stopping_criteria function
|
||||||
"""
|
"""
|
||||||
detokenizer = self.tokenizer.detokenizer
|
|
||||||
detokenizer.reset()
|
|
||||||
tokens = []
|
tokens = []
|
||||||
finish_reason = "length"
|
finish_reason = "length"
|
||||||
stop_sequence_suffix = None
|
stop_sequence_suffix = None
|
||||||
logging.debug(f"Starting completion:")
|
if self.stream:
|
||||||
|
self.end_headers()
|
||||||
|
logging.debug(f"Starting stream:")
|
||||||
|
else:
|
||||||
|
logging.debug(f"Starting completion:")
|
||||||
token_logprobs = []
|
token_logprobs = []
|
||||||
top_tokens = []
|
top_tokens = []
|
||||||
|
|
||||||
prompt = mx.array(self.get_prompt_cache(prompt))
|
prompt = self.get_prompt_cache(prompt)
|
||||||
|
|
||||||
|
text = ""
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
for _, (token, logprobs) in zip(
|
for n, (segment, token, logprobs) in enumerate(
|
||||||
range(self.max_tokens),
|
stream_generate(
|
||||||
generate_step(
|
|
||||||
prompt=prompt,
|
|
||||||
model=self.model,
|
model=self.model,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
prompt=prompt,
|
||||||
|
max_tokens=self.max_tokens,
|
||||||
temp=self.temperature,
|
temp=self.temperature,
|
||||||
top_p=self.top_p,
|
|
||||||
repetition_penalty=self.repetition_penalty,
|
repetition_penalty=self.repetition_penalty,
|
||||||
repetition_context_size=self.repetition_context_size,
|
repetition_context_size=self.repetition_context_size,
|
||||||
logit_bias=self.logit_bias,
|
logit_bias=self.logit_bias,
|
||||||
@ -481,8 +481,8 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
prompt_time = time.perf_counter() - tic
|
prompt_time = time.perf_counter() - tic
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
|
|
||||||
detokenizer.add_token(token)
|
text += segment
|
||||||
logging.debug(detokenizer.text)
|
logging.debug(text)
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
|
|
||||||
if self.logprobs > 0:
|
if self.logprobs > 0:
|
||||||
@ -503,128 +503,63 @@ class APIHandler(BaseHTTPRequestHandler):
|
|||||||
stop_sequence_suffix = self.tokenizer.decode(
|
stop_sequence_suffix = self.tokenizer.decode(
|
||||||
tokens[-stop_condition.trim_length :]
|
tokens[-stop_condition.trim_length :]
|
||||||
)
|
)
|
||||||
|
text = text[: -len(stop_sequence_suffix)]
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if self.stream:
|
||||||
|
# If the end of tokens overlaps with a stop sequence, generate new
|
||||||
|
# tokens until we know if the stop sequence is hit or not
|
||||||
|
if any(
|
||||||
|
(
|
||||||
|
sequence_overlap(tokens, sequence)
|
||||||
|
for sequence in stop_id_sequences
|
||||||
|
)
|
||||||
|
):
|
||||||
|
continue
|
||||||
|
elif segment:
|
||||||
|
response = self.generate_response(segment, None)
|
||||||
|
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||||
|
self.wfile.flush()
|
||||||
|
|
||||||
self.prompt_cache.tokens.extend(tokens)
|
self.prompt_cache.tokens.extend(tokens)
|
||||||
detokenizer.finalize()
|
|
||||||
text = (
|
|
||||||
detokenizer.text
|
|
||||||
if stop_sequence_suffix is None
|
|
||||||
else detokenizer.text[: -len(stop_sequence_suffix)]
|
|
||||||
)
|
|
||||||
gen_time = time.perf_counter() - tic
|
gen_time = time.perf_counter() - tic
|
||||||
prompt_tps = len(prompt) / prompt_time
|
prompt_tps = len(prompt) / prompt_time
|
||||||
gen_tps = len(tokens) / gen_time
|
gen_tps = len(tokens) / gen_time
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
response = self.generate_response(
|
|
||||||
text,
|
|
||||||
finish_reason,
|
|
||||||
len(prompt),
|
|
||||||
len(tokens),
|
|
||||||
token_logprobs=token_logprobs,
|
|
||||||
top_tokens=top_tokens,
|
|
||||||
tokens=tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
|
||||||
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
|
||||||
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
|
logging.debug(f"Peak memory: {peak_mem:.3f} GB")
|
||||||
response_json = json.dumps(response).encode()
|
|
||||||
indent = "\t" # Backslashes can't be inside of f-strings
|
|
||||||
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
|
||||||
|
|
||||||
# Send an additional Content-Length header when it is known
|
if self.stream:
|
||||||
self.send_header("Content-Length", str(len(response_json)))
|
response = self.generate_response(segment, finish_reason)
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
self.wfile.write(response_json)
|
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
def handle_stream(
|
|
||||||
self,
|
|
||||||
prompt: List[int],
|
|
||||||
stop_id_sequences: List[List[int]],
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Generate response to prompt and foward it to the client using a Server
|
|
||||||
Sent Events (SSE) stream.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt (mx.array): The tokenized prompt
|
|
||||||
stop_id_sequences (List[List[int]]): A list of stop words passed to
|
|
||||||
the stopping_criteria function
|
|
||||||
"""
|
|
||||||
# No additional headers are needed, call end_headers
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
detokenizer = self.tokenizer.detokenizer
|
|
||||||
detokenizer.reset()
|
|
||||||
tokens = []
|
|
||||||
|
|
||||||
stop_sequence_suffix = None
|
|
||||||
logging.debug(f"Starting stream:")
|
|
||||||
|
|
||||||
prompt = mx.array(self.get_prompt_cache(prompt))
|
|
||||||
|
|
||||||
for _, (token, _) in zip(
|
|
||||||
range(self.max_tokens),
|
|
||||||
generate_step(
|
|
||||||
prompt=prompt,
|
|
||||||
model=self.model,
|
|
||||||
temp=self.temperature,
|
|
||||||
top_p=self.top_p,
|
|
||||||
repetition_penalty=self.repetition_penalty,
|
|
||||||
repetition_context_size=self.repetition_context_size,
|
|
||||||
prompt_cache=self.prompt_cache.cache,
|
|
||||||
),
|
|
||||||
):
|
|
||||||
detokenizer.add_token(token)
|
|
||||||
logging.debug(detokenizer.text)
|
|
||||||
tokens.append(token)
|
|
||||||
|
|
||||||
stop_condition = stopping_criteria(
|
|
||||||
tokens,
|
|
||||||
stop_id_sequences,
|
|
||||||
self.tokenizer.eos_token_id,
|
|
||||||
)
|
|
||||||
if stop_condition.stop_met:
|
|
||||||
if stop_condition.trim_length:
|
|
||||||
stop_sequence_suffix = self.tokenizer.decode(
|
|
||||||
tokens[-stop_condition.trim_length :]
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
# If the end of tokens overlaps with a stop sequence, generate new
|
|
||||||
# tokens until we know if the stop sequence is hit or not
|
|
||||||
if any(
|
|
||||||
(sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
new_text = detokenizer.last_segment
|
|
||||||
if new_text:
|
|
||||||
response = self.generate_response(new_text, None)
|
|
||||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
self.prompt_cache.tokens.extend(tokens)
|
|
||||||
|
|
||||||
# check is there any remaining text to send
|
|
||||||
detokenizer.finalize()
|
|
||||||
last_segment = detokenizer.last_segment
|
|
||||||
if last_segment:
|
|
||||||
if stop_sequence_suffix is not None:
|
|
||||||
last_segment = last_segment[: -len(stop_sequence_suffix)]
|
|
||||||
response = self.generate_response(last_segment, "length")
|
|
||||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||||
self.wfile.flush()
|
self.wfile.flush()
|
||||||
|
if self.stream_options is not None and self.stream_options["include_usage"]:
|
||||||
|
response = self.completion_usage_response(len(prompt), len(tokens))
|
||||||
|
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
||||||
|
self.wfile.flush()
|
||||||
|
self.wfile.write("data: [DONE]\n\n".encode())
|
||||||
|
self.wfile.flush()
|
||||||
|
else:
|
||||||
|
response = self.generate_response(
|
||||||
|
text,
|
||||||
|
finish_reason,
|
||||||
|
len(prompt),
|
||||||
|
len(tokens),
|
||||||
|
token_logprobs=token_logprobs,
|
||||||
|
top_tokens=top_tokens,
|
||||||
|
tokens=tokens,
|
||||||
|
)
|
||||||
|
response_json = json.dumps(response).encode()
|
||||||
|
indent = "\t" # Backslashes can't be inside of f-strings
|
||||||
|
logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
|
||||||
|
|
||||||
if self.stream_options is not None and self.stream_options["include_usage"]:
|
# Send an additional Content-Length header when it is known
|
||||||
response = self.completion_usage_response(len(prompt), len(tokens))
|
self.send_header("Content-Length", str(len(response_json)))
|
||||||
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
|
self.end_headers()
|
||||||
|
self.wfile.write(response_json)
|
||||||
self.wfile.write("data: [DONE]\n\n".encode())
|
self.wfile.flush()
|
||||||
self.wfile.flush()
|
|
||||||
|
|
||||||
def completion_usage_response(
|
def completion_usage_response(
|
||||||
self,
|
self,
|
||||||
|
@ -285,7 +285,7 @@ def train(
|
|||||||
it_sec = args.steps_per_report / (stop - start)
|
it_sec = args.steps_per_report / (stop - start)
|
||||||
tokens_sec = float(n_tokens) / (stop - start)
|
tokens_sec = float(n_tokens) / (stop - start)
|
||||||
trained_tokens += n_tokens
|
trained_tokens += n_tokens
|
||||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
print(
|
print(
|
||||||
f"Iter {it}: Train loss {train_loss:.3f}, "
|
f"Iter {it}: Train loss {train_loss:.3f}, "
|
||||||
|
@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
|
|||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
from .models import cache
|
from .models import cache
|
||||||
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
from .sample_utils import make_logits_processors, make_sampler
|
||||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||||
from .tuner.utils import dequantize as dequantize_model
|
from .tuner.utils import dequantize as dequantize_model
|
||||||
from .tuner.utils import load_adapters
|
from .tuner.utils import load_adapters
|
||||||
@ -35,7 +35,8 @@ MODEL_REMAPPING = {
|
|||||||
MAX_FILE_SIZE_GB = 5
|
MAX_FILE_SIZE_GB = 5
|
||||||
|
|
||||||
# A stream on the default device just for generation
|
# A stream on the default device just for generation
|
||||||
generation_stream = mx.new_stream(mx.default_device())
|
# generation_stream = mx.new_stream(mx.default_device())
|
||||||
|
generation_stream = mx.default_stream(mx.default_device())
|
||||||
|
|
||||||
|
|
||||||
class ModelNotFoundError(Exception):
|
class ModelNotFoundError(Exception):
|
||||||
@ -155,10 +156,16 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
|
|||||||
def generate_step(
|
def generate_step(
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
|
temp: float = 0.0,
|
||||||
|
repetition_penalty: Optional[float] = None,
|
||||||
|
repetition_context_size: Optional[int] = 20,
|
||||||
|
top_p: float = 1.0,
|
||||||
|
min_p: float = 0.0,
|
||||||
|
min_tokens_to_keep: int = 1,
|
||||||
prefill_step_size: int = 512,
|
prefill_step_size: int = 512,
|
||||||
max_kv_size: Optional[int] = None,
|
max_kv_size: Optional[int] = None,
|
||||||
prompt_cache: Optional[Any] = None,
|
prompt_cache: Optional[Any] = None,
|
||||||
sampler: Optional[Callable[mx.array, mx.array]] = None,
|
logit_bias: Optional[Dict[int, float]] = None,
|
||||||
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||||
kv_bits: Optional[int] = None,
|
kv_bits: Optional[int] = None,
|
||||||
kv_group_size: int = 64,
|
kv_group_size: int = 64,
|
||||||
@ -170,14 +177,24 @@ def generate_step(
|
|||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
prompt (mx.array): The input prompt.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
|
temp (float): The temperature for sampling, if 0 the argmax is used.
|
||||||
|
Default: ``0``.
|
||||||
|
repetition_penalty (float, optional): The penalty factor for repeating
|
||||||
|
tokens.
|
||||||
|
repetition_context_size (int, optional): The number of tokens to
|
||||||
|
consider for repetition penalty. Default: ``20``.
|
||||||
|
top_p (float, optional): Nulceus sampling, higher means model considers
|
||||||
|
more less likely words.
|
||||||
|
min_p (float, optional): The minimum value (scaled by the top token's
|
||||||
|
probability) that a token probability must have to be considered.
|
||||||
|
min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
|
||||||
|
be filtered by min_p sampling.
|
||||||
prefill_step_size (int): Step size for processing the prompt.
|
prefill_step_size (int): Step size for processing the prompt.
|
||||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||||
entries (except the first 4 tokens) will be overwritten.
|
entries (except the first 4 tokens) will be overwritten.
|
||||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||||
provided, the cache will be updated in place.
|
provided, the cache will be updated in place.
|
||||||
sampler (Callable[mx.array, mx.array], optional). A function which
|
logit_bias (dictionary, optional): Additive logit bias.
|
||||||
takes log probabilities and returns tokens. If ``None`` then the
|
|
||||||
argmax is used. Default: ``None``.
|
|
||||||
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||||
A list of functions that take tokens and logits and return the processed
|
A list of functions that take tokens and logits and return the processed
|
||||||
logits. Default: ``None``.
|
logits. Default: ``None``.
|
||||||
@ -204,7 +221,11 @@ def generate_step(
|
|||||||
elif len(prompt_cache) != len(model.layers):
|
elif len(prompt_cache) != len(model.layers):
|
||||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||||
|
|
||||||
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
|
sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
|
||||||
|
logits_processors = logits_processors or []
|
||||||
|
logits_processors.extend(
|
||||||
|
make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
|
||||||
|
)
|
||||||
|
|
||||||
def _step(y):
|
def _step(y):
|
||||||
with mx.stream(generation_stream):
|
with mx.stream(generation_stream):
|
||||||
@ -222,7 +243,7 @@ def generate_step(
|
|||||||
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
prompt_cache, quantized_kv_start, kv_group_size, kv_bits
|
||||||
)
|
)
|
||||||
|
|
||||||
logprobs = logits - mx.logsumexp(logits)
|
logprobs = logits - mx.logsumexp(logits, keepdims=True)
|
||||||
y = sampler(logprobs)
|
y = sampler(logprobs)
|
||||||
return y, logprobs.squeeze(0)
|
return y, logprobs.squeeze(0)
|
||||||
|
|
||||||
@ -249,7 +270,7 @@ def generate_step(
|
|||||||
def stream_generate(
|
def stream_generate(
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
|
||||||
prompt: str,
|
prompt: Union[str, List[int]],
|
||||||
max_tokens: int = 100,
|
max_tokens: int = 100,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Union[str, Generator[str, None, None]]:
|
) -> Union[str, Generator[str, None, None]]:
|
||||||
@ -257,7 +278,7 @@ def stream_generate(
|
|||||||
A generator producing text based on the given prompt from the model.
|
A generator producing text based on the given prompt from the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
prompt (mx.array): The input prompt.
|
prompt (Union[str, List[int]]): The input prompt.
|
||||||
model (nn.Module): The model to use for generation.
|
model (nn.Module): The model to use for generation.
|
||||||
max_tokens (int): The ma
|
max_tokens (int): The ma
|
||||||
kwargs: The remaining options get passed to :func:`generate_step`.
|
kwargs: The remaining options get passed to :func:`generate_step`.
|
||||||
@ -269,23 +290,26 @@ def stream_generate(
|
|||||||
if not isinstance(tokenizer, TokenizerWrapper):
|
if not isinstance(tokenizer, TokenizerWrapper):
|
||||||
tokenizer = TokenizerWrapper(tokenizer)
|
tokenizer = TokenizerWrapper(tokenizer)
|
||||||
|
|
||||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
prompt_tokens = mx.array(
|
||||||
|
prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
|
||||||
|
)
|
||||||
detokenizer = tokenizer.detokenizer
|
detokenizer = tokenizer.detokenizer
|
||||||
|
|
||||||
detokenizer.reset()
|
with wired_limit(model, [generation_stream]):
|
||||||
for n, (token, _) in zip(
|
detokenizer.reset()
|
||||||
range(max_tokens),
|
for n, (token, logits) in zip(
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
range(max_tokens),
|
||||||
):
|
generate_step(prompt_tokens, model, **kwargs),
|
||||||
if token == tokenizer.eos_token_id:
|
):
|
||||||
break
|
if token == tokenizer.eos_token_id:
|
||||||
detokenizer.add_token(token)
|
break
|
||||||
|
detokenizer.add_token(token)
|
||||||
|
|
||||||
# Yield the last segment if streaming
|
# Yield the last segment if streaming
|
||||||
yield detokenizer.last_segment
|
yield detokenizer.last_segment, token, logits
|
||||||
|
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
yield detokenizer.last_segment
|
yield detokenizer.last_segment, token, logits
|
||||||
|
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@ -322,7 +346,7 @@ def generate(
|
|||||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
||||||
detokenizer = tokenizer.detokenizer
|
detokenizer = tokenizer.detokenizer
|
||||||
|
|
||||||
with wired_limit(model):
|
with wired_limit(model, [generation_stream]):
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
detokenizer.reset()
|
detokenizer.reset()
|
||||||
for n, (token, logprobs) in zip(
|
for n, (token, logprobs) in zip(
|
||||||
@ -361,7 +385,7 @@ def generate(
|
|||||||
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
|
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
|
||||||
)
|
)
|
||||||
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||||
print(f"Peak memory: {peak_mem:.3f} GB")
|
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||||
|
|
||||||
return detokenizer.text
|
return detokenizer.text
|
||||||
|
@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
|
|||||||
"hello",
|
"hello",
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
logits_processor=[logits_processor],
|
logits_processors=[logits_processor],
|
||||||
)
|
)
|
||||||
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
||||||
|
|
||||||
|
@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase):
|
|||||||
):
|
):
|
||||||
i += 1
|
i += 1
|
||||||
self.assertEqual(tok, toks[i])
|
self.assertEqual(tok, toks[i])
|
||||||
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
|
self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
Loading…
Reference in New Issue
Block a user