Merge remote-tracking branch 'origin/completion_only' into completion_only

Chime Ogbuji 2024-11-10 09:54:49 -05:00
commit 791727fa1c
14 changed files with 294 additions and 245 deletions

View File

@@ -188,7 +188,7 @@ The adapters are saved in `mlx_output` and can be used directly by the
 ```shell
 python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
-  --adapter mlx_output/0001200_adapters.safetensors \
+  --adapter mlx_output/final_adapters.safetensors \
   --fuse-adapter \
   --no-t5-padding \
   'A photo of an sks dog lying on the sand at a beach in Greece'
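
A quick way to sanity-check the adapters referenced above is to open the safetensors file directly. A minimal sketch, assuming the `mlx_output/final_adapters.safetensors` path from the example and an environment with `mlx` installed:

```python
import mlx.core as mx

# Load the trained LoRA weights (path assumed from the example above) and
# print a few parameter names with their shapes.
adapters = mx.load("mlx_output/final_adapters.safetensors")
for name in sorted(adapters)[:5]:
    print(name, adapters[name].shape)
```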

View File

@@ -13,7 +13,7 @@ from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten, tree_map, tree_reduce
 from PIL import Image

-from flux import FluxPipeline, Trainer, load_dataset
+from flux import FluxPipeline, Trainer, load_dataset, save_config


 def generate_progress_images(iteration, flux, args):
@@ -43,10 +43,10 @@ def generate_progress_images(iteration, flux, args):
     im.save(out_file)


-def save_adapters(iteration, flux, args):
+def save_adapters(adapter_name, flux, args):
     out_dir = Path(args.output_dir)
     out_dir.mkdir(parents=True, exist_ok=True)
-    out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
+    out_file = out_dir / adapter_name
     print(f"Saving {str(out_file)}")

     mx.save_safetensors(
@@ -157,6 +157,10 @@ if __name__ == "__main__":
     parser = setup_arg_parser()
     args = parser.parse_args()

+    output_path = Path(args.output_dir)
+    output_path.mkdir(parents=True, exist_ok=True)
+    save_config(vars(args), output_path / "adapter_config.json")
+
     # Load the model and set it up for LoRA training. We use the same random
     # state when creating the LoRA layers so all workers will have the same
     # initial weights.
@@ -278,8 +282,11 @@ if __name__ == "__main__":
             generate_progress_images(i + 1, flux, args)

         if (i + 1) % args.checkpoint_every == 0:
-            save_adapters(i + 1, flux, args)
+            save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)

         if (i + 1) % 10 == 0:
             losses = []
             tic = time.time()
+
+    save_adapters("final_adapters.safetensors", flux, args)
+    print(f"Training successful. Saved final weights to {args.adapter_file}.")
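
With this change, periodic checkpoints keep the zero-padded iteration names while the last call writes `final_adapters.safetensors`, and the training arguments are recorded up front in `adapter_config.json`. A small sketch of inspecting a finished run; the default `mlx_output` directory and the `lora_rank`/`iterations` argument names are assumptions for illustration:

```python
import json
from pathlib import Path

out_dir = Path("mlx_output")  # assumed default output directory

# Arguments recorded by save_config() at startup.
config = json.loads((out_dir / "adapter_config.json").read_text())
print(config.get("lora_rank"), config.get("iterations"))

# Periodic checkpoints plus the final weights written after the loop.
print(sorted(p.name for p in out_dir.glob("*_adapters.safetensors")))
```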

View File

@@ -12,4 +12,5 @@ from .utils import (
     load_flow_model,
     load_t5,
     load_t5_tokenizer,
+    save_config,
 )

View File

@@ -3,7 +3,8 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Optional
+from pathlib import Path
+from typing import Optional, Union

 import mlx.core as mx
 from huggingface_hub import hf_hub_download
@@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str):
 def load_t5_tokenizer(name: str, pad: bool = True):
     model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
     return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
+
+
+def save_config(
+    config: dict,
+    config_path: Union[str, Path],
+) -> None:
+    """Save the model configuration to the ``config_path``.
+
+    The final configuration will be sorted before saving for better readability.
+
+    Args:
+        config (dict): The model configuration.
+        config_path (Union[str, Path]): Model configuration file path.
+    """
+    # Sort the config for better readability
+    config = dict(sorted(config.items()))
+
+    # Write the config to the provided file
+    with open(config_path, "w") as fid:
+        json.dump(config, fid, indent=4)
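
The new `save_config` helper just sorts the dictionary and writes it as indented JSON, so it can be reused for any run metadata. A minimal usage sketch, assuming the `flux` package from this repo is importable and the output directory already exists; the key/value pairs are illustrative only:

```python
from flux import save_config

# Any JSON-serializable dict works; keys are sorted before writing.
save_config(
    {"model": "dev", "lora_rank": 8, "iterations": 1200},
    "mlx_output/adapter_config.json",
)
```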

View File

@@ -101,7 +101,8 @@ To see a description of all the arguments you can do:
 #### Streaming

 For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text. For example,
+generator object which streams the output text, token, and log probabilities.
+For example,

 ```python
 from mlx_lm import load, stream_generate
@@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )

-for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
-    print(t, end="", flush=True)
+for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(text, end="", flush=True)
 print()
 ```
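
Since each iteration now yields a `(text, token, logprobs)` tuple rather than a bare string, callers can also read the sampled token id and the per-step log probabilities. A short sketch, assuming the same 4-bit Llama model used as the default elsewhere in this change:

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

for text, token, logprobs in stream_generate(model, tokenizer, "Hello", max_tokens=8):
    # `text` is the newly decoded segment, `token` the sampled token id, and
    # `logprobs` the log-probability vector for this step.
    print(f"{token}\t{text!r}")
```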

View File

@@ -152,6 +152,7 @@ def main():
         model(y[:step_size][None], cache=cache)
         mx.eval([c.state for c in cache])
+        mx.metal.clear_cache()
         processed += min(y.size, step_size)
         y = y[step_size:]
         current = time.time()
@@ -165,14 +166,13 @@ def main():
     )
     print()
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+    print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")

     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
     save_prompt_cache(args.prompt_cache_file, cache, metadata)

View File

@@ -74,7 +74,7 @@ def main():
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
-        for response in stream_generate(
+        for response, *_ in stream_generate(
             model,
             tokenizer,
             prompt,

View File

@@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
+DEFAULT_MIN_P = 0.0
+DEFAULT_MIN_TOKENS_TO_KEEP = 1
 DEFAULT_SEED = 0
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 DEFAULT_QUANTIZED_KV_START = 5000
@@ -52,6 +54,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--prompt",
+        "-p",
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
@@ -68,6 +71,15 @@ def setup_arg_parser():
     parser.add_argument(
         "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
     )
+    parser.add_argument(
+        "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
+    )
+    parser.add_argument(
+        "--min-tokens-to-keep",
+        type=float,
+        default=DEFAULT_MIN_TOKENS_TO_KEEP,
+        help="Minimum tokens to keep for min-p sampling.",
+    )
     parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
     parser.add_argument(
         "--ignore-chat-template",
@@ -247,6 +259,8 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
+        min_p=args.min_p,
+        min_tokens_to_keep=args.min_tokens_to_keep,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
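
The new flags map one-to-one onto keyword arguments that `generate` forwards to `generate_step`, so the same sampling behaviour is available from Python. A sketch assuming the default model above; the prompt and values are illustrative only:

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# min_p / min_tokens_to_keep mirror the new --min-p / --min-tokens-to-keep flags.
text = generate(
    model,
    tokenizer,
    prompt="Write a haiku about the sea.",
    max_tokens=64,
    temp=0.7,
    min_p=0.05,
    min_tokens_to_keep=1,
)
print(text)
```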

View File

@@ -1,10 +1,83 @@
 # Copyright © 2023-2024 Apple Inc.

 from functools import partial
+from typing import Callable, Dict, Optional

 import mlx.core as mx


+def make_sampler(
+    temp: float = 0.0,
+    top_p: float = 0.0,
+    min_p: float = 0.0,
+    min_tokens_to_keep: int = 1,
+) -> Callable[mx.array, mx.array]:
+    """
+    Make a sampler function for use with ``generate_step``.
+
+    Args:
+        temp (float): The temperature for sampling, if 0 the argmax is used.
+          Default: ``0``.
+        top_p (float, optional): Nucleus sampling, higher means model considers
+          more less likely words.
+        min_p (float, optional): The minimum value (scaled by the top token's
+          probability) that a token probability must have to be considered.
+        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+          be filtered by min_p sampling.
+
+    Returns:
+        Callable[mx.array, mx.array]:
+            A sampler which takes log-probabilities and returns tokens.
+    """
+    if temp == 0:
+        return lambda x: mx.argmax(x, axis=-1)
+    elif top_p > 0 and top_p < 1.0:
+        return lambda x: top_p_sampling(x, top_p, temp)
+    elif min_p != 0.0:
+        return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
+    else:
+        return lambda x: categorical_sampling(x, temp)
+
+
+def make_logits_processors(
+    logit_bias: Optional[Dict[int, float]] = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
+):
+    """
+    Make logits processors for use with ``generate_step``.
+
+    Args:
+        repetition_penalty (float, optional): The penalty factor for repeating
+          tokens.
+        repetition_context_size (int, optional): The number of tokens to
+          consider for repetition penalty. Default: ``20``.
+        logit_bias (dictionary, optional): Additive logit bias.
+
+    Returns:
+        List[Callable[[mx.array, mx.array], mx.array]]:
+            A list of logits processors. Each processor in the list is a
+            callable which takes an array of tokens and an array of logits
+            and returns the updated logits.
+    """
+    logits_processors = []
+
+    if logit_bias:
+        indices = mx.array(list(logit_bias.keys()))
+        values = mx.array(list(logit_bias.values()))
+
+        def logit_bias_processor(_, logits):
+            logits[:, indices] += values
+            return logits
+
+        logits_processors.append(logit_bias_processor)
+
+    if repetition_penalty and repetition_penalty != 0.0:
+        logits_processors.append(
+            make_repetition_penalty(repetition_penalty, repetition_context_size)
+        )
+
+    return logits_processors
+
+
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def min_p_sampling(
     logits: mx.array,
@@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def categorical_sampling(logits, temp):
     return mx.random.categorical(logits * (1 / temp))
+
+
+def make_repetition_penalty(penalty: float, context_size: int = 20):
+    """
+    Make repetition penalty processor.
+
+    Paper: https://arxiv.org/abs/1909.05858
+
+    Args:
+        penalty (float): The repetition penalty factor to be applied.
+        context_size (int): The number of previous tokens to use.
+          Default: ``20``.
+
+    Returns:
+        Callable[[mx.array, List[int]], mx.array]:
+            The repetition penalty processor.
+    """
+    if penalty < 0 or not isinstance(penalty, float):
+        raise ValueError(f"penalty must be a non-negative float, got {penalty}")
+
+    def repetition_penalty_processor(tokens, logits):
+        if len(tokens) > 0:
+            tokens = tokens[-context_size:]
+            selected_logits = logits[:, tokens]
+            selected_logits = mx.where(
+                selected_logits < 0,
+                selected_logits * penalty,
+                selected_logits / penalty,
+            )
+            logits[:, tokens] = selected_logits
+        return logits
+
+    return repetition_penalty_processor
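
`make_sampler` and `make_logits_processors` package the sampling and logits-processing logic that `generate_step` previously built inline. A standalone sketch of how the two factories compose; the vocabulary size, bias, and token history below are made up for illustration:

```python
import mlx.core as mx

from mlx_lm.sample_utils import make_logits_processors, make_sampler

sampler = make_sampler(temp=0.7, min_p=0.05, min_tokens_to_keep=1)
processors = make_logits_processors(logit_bias={0: -10.0}, repetition_penalty=1.1)

logits = mx.random.normal((1, 32000))  # stand-in for one step of model output
history = mx.array([1, 2, 3])          # previously generated tokens

for processor in processors:
    logits = processor(history, logits)

logprobs = logits - mx.logsumexp(logits, keepdims=True)
print(sampler(logprobs))  # sampled token id as an mx.array
```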

View File

@@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir

 from ._version import __version__
 from .models.cache import make_prompt_cache
-from .utils import generate_step, load
+from .utils import load, stream_generate


 def get_system_fingerprint():
@@ -64,7 +64,7 @@ def stopping_criteria(
         end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
-        return StopCondition(stop_met=True, trim_length=1)
+        return StopCondition(stop_met=True, trim_length=0)

     for stop_ids in stop_id_sequences:
         if len(tokens) >= len(stop_ids):
@@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.max_tokens = self.body.get("max_completion_tokens", None)
         if self.max_tokens is None:
             self.max_tokens = self.body.get("max_tokens", 512)
-        self.temperature = self.body.get("temperature", 1.0)
+        self.temperature = self.body.get("temperature", 0.0)
         self.top_p = self.body.get("top_p", 1.0)
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
@@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):

         # Call endpoint specific method
         prompt = endpoints[self.path]()
-
-        # Call method based on response type
-        method = self.handle_stream if self.stream else self.handle_completion
-        method(prompt, stop_id_sequences)
+        self.handle_completion(prompt, stop_id_sequences)

     def validate_model_parameters(self):
         """
@@ -452,32 +449,40 @@
             stop_id_sequences (List[List[int]]): A list of stop words passed
                 to the stopping_criteria function
         """
-        detokenizer = self.tokenizer.detokenizer
-        detokenizer.reset()
         tokens = []
         finish_reason = "length"
         stop_sequence_suffix = None
-        logging.debug(f"Starting completion:")
+        if self.stream:
+            self.end_headers()
+            logging.debug(f"Starting stream:")
+        else:
+            logging.debug(f"Starting completion:")
         token_logprobs = []
         top_tokens = []

         prompt = self.get_prompt_cache(prompt)

-        for _, (token, logprobs) in zip(
-            range(self.max_tokens),
-            generate_step(
-                prompt=mx.array(prompt),
-                model=self.model,
+        text = ""
+        tic = time.perf_counter()
+        for n, (segment, token, logprobs) in enumerate(
+            stream_generate(
+                model=self.model,
+                tokenizer=self.tokenizer,
+                prompt=prompt,
+                max_tokens=self.max_tokens,
                 temp=self.temperature,
-                top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
                 logit_bias=self.logit_bias,
                 prompt_cache=self.prompt_cache.cache,
             ),
         ):
-            detokenizer.add_token(token)
-            logging.debug(detokenizer.text)
+            if n == 0:
+                prompt_time = time.perf_counter() - tic
+                tic = time.perf_counter()
+            text += segment
+            logging.debug(text)
             tokens.append(token)

             if self.logprobs > 0:
@@ -498,121 +503,63 @@
                     stop_sequence_suffix = self.tokenizer.decode(
                         tokens[-stop_condition.trim_length :]
                     )
+                    text = text[: -len(stop_sequence_suffix)]
                 break

-        self.prompt_cache.tokens.extend(tokens)
-        detokenizer.finalize()
-        text = (
-            detokenizer.text
-            if stop_sequence_suffix is None
-            else detokenizer.text[: -len(stop_sequence_suffix)]
-        )
-        response = self.generate_response(
-            text,
-            finish_reason,
-            len(prompt),
-            len(tokens),
-            token_logprobs=token_logprobs,
-            top_tokens=top_tokens,
-            tokens=tokens,
-        )
-
-        response_json = json.dumps(response).encode()
-        indent = "\t"  # Backslashes can't be inside of f-strings
-        logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
-
-        # Send an additional Content-Length header when it is known
-        self.send_header("Content-Length", str(len(response_json)))
-        self.end_headers()
-        self.wfile.write(response_json)
-        self.wfile.flush()
-
-    def handle_stream(
-        self,
-        prompt: List[int],
-        stop_id_sequences: List[List[int]],
-    ):
-        """
-        Generate response to prompt and foward it to the client using a Server
-        Sent Events (SSE) stream.
-
-        Args:
-            prompt (mx.array): The tokenized prompt
-            stop_id_sequences (List[List[int]]): A list of stop words passed to
-                the stopping_criteria function
-        """
-        # No additional headers are needed, call end_headers
-        self.end_headers()
-
-        detokenizer = self.tokenizer.detokenizer
-        detokenizer.reset()
-        tokens = []
-        stop_sequence_suffix = None
-
-        logging.debug(f"Starting stream:")
-
-        prompt = self.get_prompt_cache(prompt)
-
-        for _, (token, _) in zip(
-            range(self.max_tokens),
-            generate_step(
-                prompt=mx.array(prompt),
-                model=self.model,
-                temp=self.temperature,
-                top_p=self.top_p,
-                repetition_penalty=self.repetition_penalty,
-                repetition_context_size=self.repetition_context_size,
-                prompt_cache=self.prompt_cache.cache,
-            ),
-        ):
-            detokenizer.add_token(token)
-            logging.debug(detokenizer.text)
-            tokens.append(token)
-
-            stop_condition = stopping_criteria(
-                tokens,
-                stop_id_sequences,
-                self.tokenizer.eos_token_id,
-            )
-            if stop_condition.stop_met:
-                if stop_condition.trim_length:
-                    stop_sequence_suffix = self.tokenizer.decode(
-                        tokens[-stop_condition.trim_length :]
-                    )
-                break
-
-            # If the end of tokens overlaps with a stop sequence, generate new
-            # tokens until we know if the stop sequence is hit or not
-            if any(
-                (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
-            ):
-                continue
-
-            new_text = detokenizer.last_segment
-            if new_text:
-                response = self.generate_response(new_text, None)
-                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                self.wfile.flush()
+            if self.stream:
+                # If the end of tokens overlaps with a stop sequence, generate new
+                # tokens until we know if the stop sequence is hit or not
+                if any(
+                    (
+                        sequence_overlap(tokens, sequence)
+                        for sequence in stop_id_sequences
+                    )
+                ):
+                    continue
+                elif segment:
+                    response = self.generate_response(segment, None)
+                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                    self.wfile.flush()

         self.prompt_cache.tokens.extend(tokens)

-        # check is there any remaining text to send
-        detokenizer.finalize()
-        last_segment = detokenizer.last_segment
-        if last_segment:
-            if stop_sequence_suffix is not None:
-                last_segment = last_segment[: -len(stop_sequence_suffix)]
-            response = self.generate_response(last_segment, "length")
+        gen_time = time.perf_counter() - tic
+        prompt_tps = len(prompt) / prompt_time
+        gen_tps = len(tokens) / gen_time
+        peak_mem = mx.metal.get_peak_memory() / 1e9
+        logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
+        logging.debug(f"Peak memory: {peak_mem:.3f} GB")
+
+        if self.stream:
+            response = self.generate_response(segment, finish_reason)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
-
-        if self.stream_options is not None and self.stream_options["include_usage"]:
-            response = self.completion_usage_response(len(prompt), len(tokens))
-            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-
-        self.wfile.write("data: [DONE]\n\n".encode())
-        self.wfile.flush()
+            if self.stream_options is not None and self.stream_options["include_usage"]:
+                response = self.completion_usage_response(len(prompt), len(tokens))
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+            self.wfile.write("data: [DONE]\n\n".encode())
+            self.wfile.flush()
+        else:
+            response = self.generate_response(
+                text,
+                finish_reason,
+                len(prompt),
+                len(tokens),
+                token_logprobs=token_logprobs,
+                top_tokens=top_tokens,
+                tokens=tokens,
+            )
+            response_json = json.dumps(response).encode()
+            indent = "\t"  # Backslashes can't be inside of f-strings
+            logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
+
+            # Send an additional Content-Length header when it is known
+            self.send_header("Content-Length", str(len(response_json)))
+            self.end_headers()
+            self.wfile.write(response_json)
+            self.wfile.flush()

     def completion_usage_response(
         self,
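
With `handle_stream` folded into `handle_completion`, streaming and non-streaming requests share one code path, and streaming responses now get the same timing and memory logging. A small client sketch, assuming a server started locally with `mlx_lm.server` on the default port 8080 and the third-party `requests` package; the payload values are illustrative only:

```python
import json

import requests

response = requests.post(
    "http://localhost:8080/v1/chat/completions",
    json={
        "messages": [{"role": "user", "content": "Say hello"}],
        "max_tokens": 32,
        "stream": True,
    },
    stream=True,
)
for line in response.iter_lines():
    if not line or not line.startswith(b"data: ") or line == b"data: [DONE]":
        continue
    chunk = json.loads(line[len(b"data: "):])
    # Streamed chat chunks carry incremental text in choices[0].delta.
    print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
print()
```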

View File

@@ -409,7 +409,7 @@ def train(
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
-            peak_mem = mx.metal.get_peak_memory() / 2**30
+            peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
                 print(
                     f"Iter {it}: Train loss {train_loss:.3f}, "

View File

@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer

 # Local imports
 from .models import cache
-from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
+from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters
@@ -34,6 +34,9 @@ MODEL_REMAPPING = {

 MAX_FILE_SIZE_GB = 5

+# A stream on the default device just for generation
+generation_stream = mx.new_stream(mx.default_device())
+

 class ModelNotFoundError(Exception):
     def __init__(self, message):
@@ -137,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     return model_path


-def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
-    """
-    Apply repetition penalty to specific logits based on the given context.
-
-    Paper: https://arxiv.org/abs/1909.05858
-
-    Args:
-        logits (mx.array): The logits produced by the language model.
-        tokens (mx.array): A list of N previous tokens.
-        penalty (float): The repetition penalty factor to be applied.
-
-    Returns:
-        logits (mx.array): Logits with repetition penalty applied to generated tokens.
-    """
-    if len(tokens) > 0:
-        selected_logits = logits[:, tokens]
-        selected_logits = mx.where(
-            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
-        )
-        logits[:, tokens] = selected_logits
-    return logits
-
-
 def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
     if (
         kv_bits is not None
@@ -185,7 +165,7 @@ def generate_step(
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
-    logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
@@ -214,7 +194,7 @@ def generate_step(
         prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
             provided, the cache will be updated in place.
         logit_bias (dictionary, optional): Additive logit bias.
-        logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
+        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
             A list of functions that take tokens and logits and return the processed
             logits. Default: ``None``.
         kv_bits (int, optional): Number of bits to use for KV cache quantization.
@@ -224,53 +204,9 @@ def generate_step(
             when ``kv_bits`` is non-None. Default: ``0``.

     Yields:
-        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
-        one token and a vector of log probabilities.
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
     """

-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        logprobs = logits - mx.logsumexp(logits)
-
-        if temp == 0:
-            token = mx.argmax(logits, axis=-1)
-        else:
-            if top_p > 0 and top_p < 1.0:
-                token = top_p_sampling(logits, top_p, temp)
-            elif min_p != 0.0:
-                token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
-            else:
-                token = categorical_sampling(logits, temp)
-
-        return token, logprobs
-
-    if repetition_penalty and (
-        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
-    ):
-        raise ValueError(
-            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
-        )
-
-    logits_processor = logits_processor or []
-
-    if repetition_penalty:
-
-        def repetition_penalty_processor(tokens, logits):
-            return apply_repetition_penalty(
-                logits, tokens[-repetition_context_size:], repetition_penalty
-            )
-
-        logits_processor.append(repetition_penalty_processor)
-
-    if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-
-        def logit_bias_processor(_, logits):
-            logits[:, indices] += values
-            return logits
-
-        logits_processor.append(logit_bias_processor)
-
     y = prompt
     tokens = None
@@ -283,24 +219,31 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")

+    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
+    logits_processors = logits_processors or []
+    logits_processors.extend(
+        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
+    )
+
     def _step(y):
-        logits = model(y[None], cache=prompt_cache)
-        logits = logits[:, -1, :]
-
-        if logits_processor:
-            nonlocal tokens
-            tokens = mx.concat([tokens, y]) if tokens is not None else y
-
-            for processor in logits_processor:
-                logits = processor(tokens, logits)
-
-        maybe_quantize_kv_cache(
-            prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-        )
-
-        y, logprobs = sample(logits)
-        return y, logprobs.squeeze(0)
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=prompt_cache)
+            logits = logits[:, -1, :]
+
+            if logits_processors:
+                nonlocal tokens
+                tokens = mx.concat([tokens, y]) if tokens is not None else y
+
+                for processor in logits_processors:
+                    logits = processor(tokens, logits)
+
+            maybe_quantize_kv_cache(
+                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+            )
+
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            y = sampler(logprobs)
+            return y, logprobs.squeeze(0)

     while y.size > prefill_step_size:
         model(y[:prefill_step_size][None], cache=prompt_cache)
@@ -325,43 +268,51 @@

 def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     max_tokens: int = 100,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> Generator[Tuple[str, int, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.

     Args:
-        prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        max_tokens (int): The ma
+        tokenizer (PreTrainedTokenizer): The tokenizer.
+        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
         kwargs: The remaining options get passed to :func:`generate_step`.
             See :func:`generate_step` for more details.

     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing text.
+        Tuple[str, int, mx.array]:
+            The next text segment, token, and vector of log probabilities.
     """
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)

-    prompt_tokens = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(
+        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
+    )
     detokenizer = tokenizer.detokenizer

-    detokenizer.reset()
-    for n, (token, _) in zip(
-        range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
-    ):
-        if token == tokenizer.eos_token_id:
-            break
-        detokenizer.add_token(token)
-
-        # Yield the last segment if streaming
-        yield detokenizer.last_segment
-
-    detokenizer.finalize()
-    yield detokenizer.last_segment
+    with wired_limit(model, [generation_stream]):
+        detokenizer.reset()
+        for n, (token, logits) in zip(
+            range(max_tokens),
+            generate_step(prompt_tokens, model, **kwargs),
+        ):
+            if token == tokenizer.eos_token_id:
+                break
+
+            detokenizer.add_token(token)
+
+            if n == (max_tokens - 1):
+                break
+
+            yield detokenizer.last_segment, token, logits
+
+        detokenizer.finalize()
+        yield detokenizer.last_segment, token, logits


 def generate(
@@ -372,7 +323,7 @@ def generate(
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> str:
     """
     Generate a complete response from the model.
@@ -398,7 +349,7 @@ def generate(
     prompt_tokens = mx.array(tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer

-    with wired_limit(model):
+    with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
         detokenizer.reset()
         for n, (token, logprobs) in zip(
@@ -416,8 +367,7 @@ def generate(
             if formatter:
                 # We have to finalize so that the prob corresponds to the last segment
                 detokenizer.finalize()
-                with mx.stream(mx.cpu):
-                    prob = mx.exp(logprobs[token]).item()
+                prob = mx.exp(logprobs[token]).item()
                 formatter(detokenizer.last_segment, prob)
             else:
                 print(detokenizer.last_segment, end="", flush=True)
@@ -438,7 +388,7 @@ def generate(
                 f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
             )
             print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-            peak_mem = mx.metal.get_peak_memory() / 2**30
+            peak_mem = mx.metal.get_peak_memory() / 1e9
             print(f"Peak memory: {peak_mem:.3f} GB")

     return detokenizer.text
@@ -623,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
         f"""
 # {upload_repo}

-The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
+The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
+converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
+using mlx-lm version **{__version__}**.

 ## Use with mlx
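
The renamed `logits_processors` argument takes a list of `(tokens, logits) -> logits` callables, which is exactly what `make_logits_processors` returns, and custom processors can also be passed through `generate`. A sketch with a toy processor; the model name and the token id being demoted are illustrative only:

```python
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

def demote_token(tokens, logits):
    # Toy processor: push an arbitrary token id far down before sampling.
    logits[:, 0] = -1e9
    return logits

text = generate(
    model,
    tokenizer,
    prompt="Hello",
    max_tokens=32,
    logits_processors=[demote_token],
)
print(text)
```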

View File

@@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
             "hello",
             max_tokens=5,
             verbose=False,
-            logits_processor=[logits_processor],
+            logits_processors=[logits_processor],
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)

View File

@@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase):
         ):
             i += 1
             self.assertEqual(tok, toks[i])
-            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
+            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))


 if __name__ == "__main__":