diff --git a/llms/README.md b/llms/README.md
index f539988a..eeb3ed6a 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -101,7 +101,8 @@ To see a description of all the arguments you can do:
 #### Streaming
 
 For streaming generation, use the `stream_generate` function. This returns a
-generator object which streams the output text. For example,
+generator object which streams the output text, token, and log probabilities.
+For example,
 
 ```python
 from mlx_lm import load, stream_generate
@@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template(
     messages, tokenize=False, add_generation_prompt=True
 )
 
-for t in stream_generate(model, tokenizer, prompt, max_tokens=512):
-    print(t, end="", flush=True)
+for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512):
+    print(text, end="", flush=True)
 print()
 ```
 
@@ -221,6 +222,7 @@ Here are a few examples of Hugging Face models that work with this example:
 - [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct)
 - [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b)
 - [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b)
+- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct)
 
 Most
 [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending),
diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index 7bb06411..987b640d 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -152,6 +152,7 @@ def main():
 
         model(y[:step_size][None], cache=cache)
         mx.eval([c.state for c in cache])
+        mx.metal.clear_cache()
         processed += min(y.size, step_size)
         y = y[step_size:]
         current = time.time()
@@ -165,14 +166,13 @@ def main():
     )
     print()
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
+    print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB")
 
     print("Saving...")
     metadata = {}
     metadata["model"] = args.model
     metadata["chat_template"] = tokenizer.chat_template
     metadata["tokenizer_config"] = json.dumps(tokenizer_config)
-    print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
     save_prompt_cache(args.prompt_cache_file, cache, metadata)
 
diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index ea1a99c7..c03056a6 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -11,6 +11,7 @@ from .utils import load, stream_generate
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
 DEFAULT_SEED = 0
+DEFAULT_MAX_TOKENS = 256
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 
 
@@ -41,6 +42,13 @@ def setup_arg_parser():
         help="Set the maximum key-value cache size",
         default=None,
     )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=DEFAULT_MAX_TOKENS,
+        help="Maximum number of tokens to generate",
+    )
     return parser
 
 
@@ -66,10 +74,11 @@ def main():
         prompt = tokenizer.apply_chat_template(
             messages, tokenize=False, add_generation_prompt=True
         )
-        for response in stream_generate(
+        for response, *_ in stream_generate(
             model,
             tokenizer,
             prompt,
+            args.max_tokens,
             temp=args.temp,
             top_p=args.top_p,
             prompt_cache=prompt_cache,
diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index 29976da2..51169def 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -13,6 +13,8 @@ DEFAULT_PROMPT = "hello"
 DEFAULT_MAX_TOKENS = 100
 DEFAULT_TEMP = 0.0
 DEFAULT_TOP_P = 1.0
+DEFAULT_MIN_P = 0.0
+DEFAULT_MIN_TOKENS_TO_KEEP = 1
 DEFAULT_SEED = 0
 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
 DEFAULT_QUANTIZED_KV_START = 5000
@@ -52,6 +54,7 @@ def setup_arg_parser():
     )
     parser.add_argument(
         "--prompt",
+        "-p",
         default=DEFAULT_PROMPT,
         help="Message to be processed by the model ('-' reads from stdin)",
     )
@@ -68,6 +71,15 @@ def setup_arg_parser():
     parser.add_argument(
         "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
     )
+    parser.add_argument(
+        "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p"
+    )
+    parser.add_argument(
+        "--min-tokens-to-keep",
+        type=int,
+        default=DEFAULT_MIN_TOKENS_TO_KEEP,
+        help="Minimum tokens to keep for min-p sampling.",
+    )
     parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
     parser.add_argument(
         "--ignore-chat-template",
@@ -247,6 +259,8 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
+        min_p=args.min_p,
+        min_tokens_to_keep=args.min_tokens_to_keep,
         max_kv_size=args.max_kv_size,
         prompt_cache=prompt_cache if using_cache else None,
         kv_bits=args.kv_bits,
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index cda41c79..f02f49b1 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -42,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     if cache is not None and cache[0] is not None:
         c = cache[0]
         if hasattr(c, "max_size"):
-            offset = min(c.max_size - 1, c.offset)
+            offset = min(c.max_size, c.offset)
             window_size = c.max_size
         else:
             offset = c.offset
diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 1cd5289d..14026f0c 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -325,9 +325,9 @@ class RotatingKVCache(_BaseCache):
             self.keys = self._temporal_order(self.keys)
             self.values = self._temporal_order(self.values)
 
-            # The largest size is self.max_size + S - 1 to ensure
+            # The largest size is self.max_size + S to ensure
             # every token gets at least self.max_size context
-            trim_size = self._idx - self.max_size + 1
+            trim_size = self._idx - self.max_size
             self.keys = self._trim(trim_size, self.keys, keys)
             self.values = self._trim(trim_size, self.values, values)
         self.offset += keys.shape[2]
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index 84f498e9..f2414660 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs):
     use_conv_bias: bool
     time_step_rank: int
     tie_word_embeddings: bool = True
+    use_bcdt_rms: bool = False
+    mixer_rms_eps: float = 1e-6
 
     def __post_init__(self):
         if not hasattr(self, "hidden_size") and hasattr(self, "d_model"):
@@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs):
 
         if self.time_step_rank == "auto":
             self.time_step_rank = math.ceil(self.hidden_size / 16)
+        if self.model_type == "falcon_mamba":
+            self.use_bcdt_rms = True
 
 
 class DepthWiseConv1d(nn.Module):
@@ -83,6 +87,11 @@ class MambaBlock(nn.Module):
         self.intermediate_size = args.intermediate_size
         self.time_step_rank = int(args.time_step_rank)
         self.use_conv_bias = args.use_conv_bias
+        self.use_bcdt_rms = args.use_bcdt_rms
+        if self.use_bcdt_rms:
+            self.mixer_norm = lambda x: mx.fast.rms_norm(
+                x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps
+            )
 
         self.in_proj = nn.Linear(
             self.hidden_size, self.intermediate_size * 2, bias=args.use_bias
@@ -126,6 +135,8 @@ class MambaBlock(nn.Module):
             ],
             axis=-1,
         )
+        if self.use_bcdt_rms:
+            delta, B, C = map(self.mixer_norm, (delta, B, C))
         delta = nn.softplus(self.dt_proj(delta))
         new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
         if state is not None:
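For context, a minimal usage sketch of the new min-p options from Python — not part of this patch; the model path is illustrative, and `min_p`/`min_tokens_to_keep` reach `generate_step` through `generate`'s kwargs:

```python
from mlx_lm import load, generate

# Illustrative model path; any compatible mlx-community model works the same way.
model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# min_p keeps only tokens whose probability is at least min_p times the
# probability of the most likely token; min_tokens_to_keep bounds the filter
# so at least that many candidates always survive.
text = generate(
    model,
    tokenizer,
    prompt="hello",
    temp=0.7,
    min_p=0.05,
    min_tokens_to_keep=1,
)
print(text)
```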
diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py
index 20b008fa..c27b52d8 100644
--- a/llms/mlx_lm/sample_utils.py
+++ b/llms/mlx_lm/sample_utils.py
@@ -1,10 +1,83 @@
 # Copyright © 2023-2024 Apple Inc.
 
 from functools import partial
+from typing import Callable, Dict, Optional
 
 import mlx.core as mx
 
 
+def make_sampler(
+    temp: float = 0.0,
+    top_p: float = 0.0,
+    min_p: float = 0.0,
+    min_tokens_to_keep: int = 1,
+) -> Callable[[mx.array], mx.array]:
+    """
+    Make a sampler function for use with ``generate_step``.
+
+    Args:
+        temp (float): The temperature for sampling, if 0 the argmax is used.
+            Default: ``0``.
+        top_p (float, optional): Nucleus sampling, higher means the model
+            considers more less-likely words.
+        min_p (float, optional): The minimum value (scaled by the top token's
+            probability) that a token probability must have to be considered.
+        min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
+            be filtered by min_p sampling.
+
+    Returns:
+        Callable[[mx.array], mx.array]:
+            A sampler which takes log-probabilities and returns tokens.
+    """
+    if temp == 0:
+        return lambda x: mx.argmax(x, axis=-1)
+    elif top_p > 0 and top_p < 1.0:
+        return lambda x: top_p_sampling(x, top_p, temp)
+    elif min_p != 0.0:
+        return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp)
+    else:
+        return lambda x: categorical_sampling(x, temp)
+
+
+def make_logits_processors(
+    logit_bias: Optional[Dict[int, float]] = None,
+    repetition_penalty: Optional[float] = None,
+    repetition_context_size: Optional[int] = 20,
+):
+    """
+    Make logits processors for use with ``generate_step``.
+
+    Args:
+        repetition_penalty (float, optional): The penalty factor for repeating
+            tokens.
+        repetition_context_size (int, optional): The number of tokens to
+            consider for repetition penalty. Default: ``20``.
+        logit_bias (dictionary, optional): Additive logit bias.
+
+    Returns:
+        List[Callable[[mx.array, mx.array], mx.array]]:
+            A list of logits processors. Each processor in the list is a
+            callable which takes an array of tokens and an array of logits
+            and returns the updated logits.
+    """
+    logits_processors = []
+    if logit_bias:
+        indices = mx.array(list(logit_bias.keys()))
+        values = mx.array(list(logit_bias.values()))
+
+        def logit_bias_processor(_, logits):
+            logits[:, indices] += values
+            return logits
+
+        logits_processors.append(logit_bias_processor)
+
+    if repetition_penalty and repetition_penalty != 0.0:
+        logits_processors.append(
+            make_repetition_penalty(repetition_penalty, repetition_context_size)
+        )
+    return logits_processors
+
+
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def min_p_sampling(
     logits: mx.array,
@@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr
 @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state)
 def categorical_sampling(logits, temp):
     return mx.random.categorical(logits * (1 / temp))
+
+
+def make_repetition_penalty(penalty: float, context_size: int = 20):
+    """
+    Make repetition penalty processor.
+
+    Paper: https://arxiv.org/abs/1909.05858
+
+    Args:
+        penalty (float): The repetition penalty factor to be applied.
+        context_size (int): The number of previous tokens to use.
+            Default: ``20``.
+
+    Returns:
+        Callable[[mx.array, List[int]], mx.array]:
+            The repetition penalty processor.
+ """ + if penalty < 0 or not isinstance(penalty, float): + raise ValueError(f"penalty must be a non-negative float, got {penalty}") + + def repetition_penalty_processor(tokens, logits): + if len(tokens) > 0: + tokens = tokens[-context_size:] + selected_logits = logits[:, tokens] + selected_logits = mx.where( + selected_logits < 0, + selected_logits * penalty, + selected_logits / penalty, + ) + logits[:, tokens] = selected_logits + return logits + + return repetition_penalty_processor diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index ec659969..c1365b36 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir from ._version import __version__ from .models.cache import make_prompt_cache -from .utils import generate_step, load +from .utils import load, stream_generate def get_system_fingerprint(): @@ -64,7 +64,7 @@ def stopping_criteria( end if it has (`trim_length`). """ if tokens and tokens[-1] == eos_token_id: - return StopCondition(stop_met=True, trim_length=1) + return StopCondition(stop_met=True, trim_length=0) for stop_ids in stop_id_sequences: if len(tokens) >= len(stop_ids): @@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler): self.max_tokens = self.body.get("max_completion_tokens", None) if self.max_tokens is None: self.max_tokens = self.body.get("max_tokens", 512) - self.temperature = self.body.get("temperature", 1.0) + self.temperature = self.body.get("temperature", 0.0) self.top_p = self.body.get("top_p", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0) self.repetition_context_size = self.body.get("repetition_context_size", 20) @@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler): # Call endpoint specific method prompt = endpoints[self.path]() - - # Call method based on response type - method = self.handle_stream if self.stream else self.handle_completion - method(prompt, stop_id_sequences) + self.handle_completion(prompt, stop_id_sequences) def validate_model_parameters(self): """ @@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler): stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ - detokenizer = self.tokenizer.detokenizer - detokenizer.reset() tokens = [] finish_reason = "length" stop_sequence_suffix = None - logging.debug(f"Starting completion:") + if self.stream: + self.end_headers() + logging.debug(f"Starting stream:") + else: + logging.debug(f"Starting completion:") token_logprobs = [] top_tokens = [] prompt = self.get_prompt_cache(prompt) - for _, (token, logprobs) in zip( - range(self.max_tokens), - generate_step( - prompt=mx.array(prompt), + text = "" + tic = time.perf_counter() + for n, (segment, token, logprobs) in enumerate( + stream_generate( model=self.model, + tokenizer=self.tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, temp=self.temperature, - top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, logit_bias=self.logit_bias, prompt_cache=self.prompt_cache.cache, ), ): - detokenizer.add_token(token) - logging.debug(detokenizer.text) + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + + text += segment + logging.debug(text) tokens.append(token) if self.logprobs > 0: @@ -498,121 +503,63 @@ class APIHandler(BaseHTTPRequestHandler): stop_sequence_suffix = self.tokenizer.decode( tokens[-stop_condition.trim_length :] ) + text = text[: -len(stop_sequence_suffix)] break - 
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index ec659969..c1365b36 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir
 
 from ._version import __version__
 from .models.cache import make_prompt_cache
-from .utils import generate_step, load
+from .utils import load, stream_generate
 
 
 def get_system_fingerprint():
@@ -64,7 +64,7 @@ def stopping_criteria(
         end if it has (`trim_length`).
     """
     if tokens and tokens[-1] == eos_token_id:
-        return StopCondition(stop_met=True, trim_length=1)
+        return StopCondition(stop_met=True, trim_length=0)
 
     for stop_ids in stop_id_sequences:
         if len(tokens) >= len(stop_ids):
@@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler):
         self.max_tokens = self.body.get("max_completion_tokens", None)
         if self.max_tokens is None:
             self.max_tokens = self.body.get("max_tokens", 512)
-        self.temperature = self.body.get("temperature", 1.0)
+        self.temperature = self.body.get("temperature", 0.0)
         self.top_p = self.body.get("top_p", 1.0)
         self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
         self.repetition_context_size = self.body.get("repetition_context_size", 20)
@@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler):
 
         # Call endpoint specific method
         prompt = endpoints[self.path]()
-
-        # Call method based on response type
-        method = self.handle_stream if self.stream else self.handle_completion
-        method(prompt, stop_id_sequences)
+        self.handle_completion(prompt, stop_id_sequences)
 
     def validate_model_parameters(self):
         """
@@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler):
             stop_id_sequences (List[List[int]]): A list of stop words passed
                 to the stopping_criteria function
         """
-        detokenizer = self.tokenizer.detokenizer
-        detokenizer.reset()
         tokens = []
         finish_reason = "length"
         stop_sequence_suffix = None
-        logging.debug(f"Starting completion:")
+        if self.stream:
+            self.end_headers()
+            logging.debug(f"Starting stream:")
+        else:
+            logging.debug(f"Starting completion:")
         token_logprobs = []
         top_tokens = []
 
         prompt = self.get_prompt_cache(prompt)
 
-        for _, (token, logprobs) in zip(
-            range(self.max_tokens),
-            generate_step(
-                prompt=mx.array(prompt),
+        text = ""
+        tic = time.perf_counter()
+        for n, (segment, token, logprobs) in enumerate(
+            stream_generate(
                 model=self.model,
+                tokenizer=self.tokenizer,
+                prompt=prompt,
+                max_tokens=self.max_tokens,
                 temp=self.temperature,
-                top_p=self.top_p,
                 repetition_penalty=self.repetition_penalty,
                 repetition_context_size=self.repetition_context_size,
                 logit_bias=self.logit_bias,
                 prompt_cache=self.prompt_cache.cache,
            ),
        ):
-            detokenizer.add_token(token)
-            logging.debug(detokenizer.text)
+            if n == 0:
+                prompt_time = time.perf_counter() - tic
+                tic = time.perf_counter()
+
+            text += segment
+            logging.debug(text)
             tokens.append(token)
 
             if self.logprobs > 0:
@@ -498,121 +503,63 @@
                 stop_sequence_suffix = self.tokenizer.decode(
                     tokens[-stop_condition.trim_length :]
                 )
+                text = text[: -len(stop_sequence_suffix)]
                 break
 
-        self.prompt_cache.tokens.extend(tokens)
-        detokenizer.finalize()
-        text = (
-            detokenizer.text
-            if stop_sequence_suffix is None
-            else detokenizer.text[: -len(stop_sequence_suffix)]
-        )
-        response = self.generate_response(
-            text,
-            finish_reason,
-            len(prompt),
-            len(tokens),
-            token_logprobs=token_logprobs,
-            top_tokens=top_tokens,
-            tokens=tokens,
-        )
-
-        response_json = json.dumps(response).encode()
-        indent = "\t"  # Backslashes can't be inside of f-strings
-        logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
-
-        # Send an additional Content-Length header when it is known
-        self.send_header("Content-Length", str(len(response_json)))
-        self.end_headers()
-
-        self.wfile.write(response_json)
-        self.wfile.flush()
-
-    def handle_stream(
-        self,
-        prompt: List[int],
-        stop_id_sequences: List[List[int]],
-    ):
-        """
-        Generate response to prompt and foward it to the client using a Server
-        Sent Events (SSE) stream.
-
-        Args:
-            prompt (mx.array): The tokenized prompt
-            stop_id_sequences (List[List[int]]): A list of stop words passed to
-                the stopping_criteria function
-        """
-        # No additional headers are needed, call end_headers
-        self.end_headers()
-
-        detokenizer = self.tokenizer.detokenizer
-        detokenizer.reset()
-        tokens = []
-
-        stop_sequence_suffix = None
-        logging.debug(f"Starting stream:")
-
-        prompt = self.get_prompt_cache(prompt)
-
-        for _, (token, _) in zip(
-            range(self.max_tokens),
-            generate_step(
-                prompt=mx.array(prompt),
-                model=self.model,
-                temp=self.temperature,
-                top_p=self.top_p,
-                repetition_penalty=self.repetition_penalty,
-                repetition_context_size=self.repetition_context_size,
-                prompt_cache=self.prompt_cache.cache,
-            ),
-        ):
-            detokenizer.add_token(token)
-            logging.debug(detokenizer.text)
-            tokens.append(token)
-
-            stop_condition = stopping_criteria(
-                tokens,
-                stop_id_sequences,
-                self.tokenizer.eos_token_id,
-            )
-            if stop_condition.stop_met:
-                if stop_condition.trim_length:
-                    stop_sequence_suffix = self.tokenizer.decode(
-                        tokens[-stop_condition.trim_length :]
+            if self.stream:
+                # If the end of tokens overlaps with a stop sequence, generate new
+                # tokens until we know if the stop sequence is hit or not
+                if any(
+                    (
+                        sequence_overlap(tokens, sequence)
+                        for sequence in stop_id_sequences
                     )
-                break
-
-            # If the end of tokens overlaps with a stop sequence, generate new
-            # tokens until we know if the stop sequence is hit or not
-            if any(
-                (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences)
-            ):
-                continue
-
-            new_text = detokenizer.last_segment
-            if new_text:
-                response = self.generate_response(new_text, None)
-                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-                self.wfile.flush()
+                ):
+                    continue
+                elif segment:
+                    response = self.generate_response(segment, None)
+                    self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                    self.wfile.flush()
 
         self.prompt_cache.tokens.extend(tokens)
 
-        # check is there any remaining text to send
-        detokenizer.finalize()
-        last_segment = detokenizer.last_segment
-        if last_segment:
-            if stop_sequence_suffix is not None:
-                last_segment = last_segment[: -len(stop_sequence_suffix)]
-            response = self.generate_response(last_segment, "length")
+        gen_time = time.perf_counter() - tic
+        prompt_tps = len(prompt) / prompt_time
+        gen_tps = len(tokens) / gen_time
+        peak_mem = mx.metal.get_peak_memory() / 1e9
+        logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
+        logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec")
+        logging.debug(f"Peak memory: {peak_mem:.3f} GB")
+
+        if self.stream:
+            response = self.generate_response(segment, finish_reason)
             self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
             self.wfile.flush()
+            if self.stream_options is not None and self.stream_options["include_usage"]:
+                response = self.completion_usage_response(len(prompt), len(tokens))
+                self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
+                self.wfile.flush()
+            self.wfile.write("data: [DONE]\n\n".encode())
+            self.wfile.flush()
+        else:
+            response = self.generate_response(
+                text,
+                finish_reason,
+                len(prompt),
+                len(tokens),
+                token_logprobs=token_logprobs,
+                top_tokens=top_tokens,
+                tokens=tokens,
+            )
+            response_json = json.dumps(response).encode()
+            indent = "\t"  # Backslashes can't be inside of f-strings
+            logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}")
 
-        if self.stream_options is not None and self.stream_options["include_usage"]:
-            response = self.completion_usage_response(len(prompt), len(tokens))
-            self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
-
-        self.wfile.write("data: [DONE]\n\n".encode())
-        self.wfile.flush()
+            # Send an additional Content-Length header when it is known
+            self.send_header("Content-Length", str(len(response_json)))
+            self.end_headers()
+            self.wfile.write(response_json)
+            self.wfile.flush()
 
     def completion_usage_response(
         self,
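For reference, a hedged client-side sketch of consuming the merged handler's SSE stream — the host, port, endpoint, and chunk shape assume the OpenAI-style defaults this server mimics and may need adjusting:

```python
import json
import urllib.request

req = urllib.request.Request(
    "http://localhost:8080/v1/chat/completions",  # assumed default host/port
    data=json.dumps(
        {
            "messages": [{"role": "user", "content": "hello"}],
            "stream": True,
        }
    ).encode(),
    headers={"Content-Type": "application/json"},
)

with urllib.request.urlopen(req) as response:
    for raw_line in response:
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":  # the server's end-of-stream sentinel
            break
        chunk = json.loads(payload)
        # Assumed OpenAI-style chunk layout with a "delta" per choice.
        print(chunk["choices"][0]["delta"].get("content", ""), end="", flush=True)
print()
```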
diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py
index 568a672d..9d390733 100644
--- a/llms/mlx_lm/tokenizer_utils.py
+++ b/llms/mlx_lm/tokenizer_utils.py
@@ -6,12 +6,6 @@ from transformers import AutoTokenizer
 REPLACEMENT_CHAR = "\ufffd"
 
 
-def _remove_space(x):
-    if x and x[0] == " ":
-        return x[1:]
-    return x
-
-
 class StreamingDetokenizer:
     """The streaming detokenizer interface so that we can detokenize one
     token at a time.
@@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer):
 
     def __init__(self, tokenizer, trim_space=True):
         self.trim_space = trim_space
+        self._sep = "\u2581".encode()
 
         # Extract the tokens in a list from id to text
         self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1)
         for value, tokenid in tokenizer.vocab.items():
-            self.tokenmap[tokenid] = value
-
-        # Replace bytes with their value
-        for i in range(len(self.tokenmap)):
-            if self.tokenmap[i].startswith("<0x"):
-                self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16))
+            if value.startswith("<0x"):
+                # Replace bytes with their value
+                self.tokenmap[tokenid] = bytes([int(value[3:5], 16)])
+            else:
+                self.tokenmap[tokenid] = value.encode()
 
         self.reset()
 
     def reset(self):
         self.offset = 0
-        self._unflushed = ""
+        self._unflushed = b""
         self.text = ""
         self.tokens = []
 
+    def _flush(self):
+        text = self._unflushed.replace(self._sep, b" ").decode("utf-8")
+        if not self.text and self.trim_space and text and text[0] == " ":
+            text = text[1:]
+        self.text += text
+
     def add_token(self, token):
         v = self.tokenmap[token]
-        if v[0] == "\u2581":
-            if self.text or not self.trim_space:
-                self.text += self._unflushed.replace("\u2581", " ")
-            else:
-                self.text = _remove_space(self._unflushed.replace("\u2581", " "))
+        if v.startswith(self._sep):
+            self._flush()
             self._unflushed = v
         else:
             self._unflushed += v
 
     def finalize(self):
-        if self.text or not self.trim_space:
-            self.text += self._unflushed.replace("\u2581", " ")
-        else:
-            self.text = _remove_space(self._unflushed.replace("\u2581", " "))
-        self._unflushed = ""
+        self._flush()
+        self._unflushed = b""
 
 
 class BPEStreamingDetokenizer(StreamingDetokenizer):
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 38619d95..21b1af18 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -285,7 +285,7 @@ def train(
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
-            peak_mem = mx.metal.get_peak_memory() / 2**30
+            peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
                 print(
                     f"Iter {it}: Train loss {train_loss:.3f}, "
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b9fc202d..8893b570 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -20,7 +20,7 @@ from transformers import PreTrainedTokenizer
 
 # Local imports
 from .models import cache
-from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
+from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters
@@ -29,10 +29,14 @@ from .tuner.utils import load_adapters
 MODEL_REMAPPING = {
     "mistral": "llama",  # mistral is compatible with llama
     "phi-msft": "phixtral",
+    "falcon_mamba": "mamba",
 }
 
 MAX_FILE_SIZE_GB = 5
 
+# A stream on the default device just for generation
+generation_stream = mx.new_stream(mx.default_device())
+
 
 class ModelNotFoundError(Exception):
     def __init__(self, message):
@@ -136,29 +140,6 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     return model_path
 
 
-def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
-    """
-    Apply repetition penalty to specific logits based on the given context.
-
-    Paper: https://arxiv.org/abs/1909.05858
-
-    Args:
-        logits (mx.array): The logits produced by the language model.
-        tokens (mx.array): A list of N previous tokens.
-        penalty (float): The repetition penalty factor to be applied.
-
-    Returns:
-        logits (mx.array): Logits with repetition penalty applied to generated tokens.
-    """
-    if len(tokens) > 0:
-        selected_logits = logits[:, tokens]
-        selected_logits = mx.where(
-            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
-        )
-        logits[:, tokens] = selected_logits
-    return logits
-
-
 def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
     if (
         kv_bits is not None
@@ -184,7 +165,7 @@ def generate_step(
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
-    logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
@@ -213,7 +194,7 @@ def generate_step(
         prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
             provided, the cache will be updated in place.
         logit_bias (dictionary, optional): Additive logit bias.
-        logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
+        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
             A list of functions that take tokens and logits and return the processed
             logits. Default: ``None``.
         kv_bits (int, optional): Number of bits to use for KV cache quantization.
@@ -223,53 +204,9 @@ def generate_step(
             when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
-        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
-        one token and a vector of log probabilities.
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
     """
 
-    def sample(logits: mx.array) -> Tuple[mx.array, float]:
-        logprobs = logits - mx.logsumexp(logits)
-
-        if temp == 0:
-            token = mx.argmax(logits, axis=-1)
-        else:
-            if top_p > 0 and top_p < 1.0:
-                token = top_p_sampling(logits, top_p, temp)
-            elif min_p != 0.0:
-                token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp)
-            else:
-                token = categorical_sampling(logits, temp)
-
-        return token, logprobs
-
-    if repetition_penalty and (
-        repetition_penalty < 0 or not isinstance(repetition_penalty, float)
-    ):
-        raise ValueError(
-            f"repetition_penalty must be a non-negative float, got {repetition_penalty}"
-        )
-
-    logits_processor = logits_processor or []
-
-    if repetition_penalty:
-
-        def repetition_penalty_processor(tokens, logits):
-            return apply_repetition_penalty(
-                logits, tokens[-repetition_context_size:], repetition_penalty
-            )
-
-        logits_processor.append(repetition_penalty_processor)
-
-    if logit_bias:
-        indices = mx.array(list(logit_bias.keys()))
-        values = mx.array(list(logit_bias.values()))
-
-        def logit_bias_processor(_, logits):
-            logits[:, indices] += values
-            return logits
-
-        logits_processor.append(logit_bias_processor)
-
     y = prompt
     tokens = None
 
@@ -282,24 +219,31 @@ def generate_step(
     elif len(prompt_cache) != len(model.layers):
         raise ValueError("Wrong number of layers in the prompt cache.")
 
+    sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep)
+    logits_processors = logits_processors or []
+    logits_processors.extend(
+        make_logits_processors(logit_bias, repetition_penalty, repetition_context_size)
+    )
+
     def _step(y):
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=prompt_cache)
+            logits = logits[:, -1, :]
 
-        logits = model(y[None], cache=prompt_cache)
-        logits = logits[:, -1, :]
+            if logits_processors:
+                nonlocal tokens
+                tokens = mx.concat([tokens, y]) if tokens is not None else y
 
-        if logits_processor:
-            nonlocal tokens
-            tokens = mx.concat([tokens, y]) if tokens is not None else y
+                for processor in logits_processors:
+                    logits = processor(tokens, logits)
 
-            for processor in logits_processor:
-                logits = processor(tokens, logits)
+            maybe_quantize_kv_cache(
+                prompt_cache, quantized_kv_start, kv_group_size, kv_bits
+            )
 
-        maybe_quantize_kv_cache(
-            prompt_cache, quantized_kv_start, kv_group_size, kv_bits
-        )
-
-        y, logprobs = sample(logits)
-        return y, logprobs.squeeze(0)
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            y = sampler(logprobs)
+        return y, logprobs.squeeze(0)
 
     while y.size > prefill_step_size:
         model(y[:prefill_step_size][None], cache=prompt_cache)
@@ -324,43 +268,51 @@ def stream_generate(
     model: nn.Module,
     tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper],
-    prompt: str,
+    prompt: Union[str, List[int]],
     max_tokens: int = 100,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> Generator[Tuple[str, int, mx.array], None, None]:
     """
     A generator producing text based on the given prompt from the model.
 
     Args:
-        prompt (mx.array): The input prompt.
         model (nn.Module): The model to use for generation.
-        max_tokens (int): The ma
+        tokenizer (PreTrainedTokenizer): The tokenizer.
+        prompt (Union[str, List[int]]): The input prompt string or integer tokens.
+        max_tokens (int): The maximum number of tokens. Default: ``100``.
         kwargs: The remaining options get passed to :func:`generate_step`.
             See :func:`generate_step` for more details.
 
     Yields:
-        Generator[Tuple[mx.array, mx.array]]: A generator producing text.
+        Tuple[str, int, mx.array]:
+            The next text segment, token, and vector of log probabilities.
     """
     if not isinstance(tokenizer, TokenizerWrapper):
         tokenizer = TokenizerWrapper(tokenizer)
 
-    prompt_tokens = mx.array(tokenizer.encode(prompt))
+    prompt_tokens = mx.array(
+        prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
+    )
     detokenizer = tokenizer.detokenizer
 
-    detokenizer.reset()
-    for n, (token, _) in zip(
-        range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
-    ):
-        if token == tokenizer.eos_token_id:
-            break
-        detokenizer.add_token(token)
+    with wired_limit(model, [generation_stream]):
+        detokenizer.reset()
+        for n, (token, logprobs) in zip(
+            range(max_tokens),
+            generate_step(prompt_tokens, model, **kwargs),
+        ):
+            if token == tokenizer.eos_token_id:
+                break
 
-        # Yield the last segment if streaming
-        yield detokenizer.last_segment
+            detokenizer.add_token(token)
 
-    detokenizer.finalize()
-    yield detokenizer.last_segment
+            if n == (max_tokens - 1):
+                break
+
+            yield detokenizer.last_segment, token, logprobs
+
+        detokenizer.finalize()
+        yield detokenizer.last_segment, token, logprobs
 
 
 def generate(
@@ -371,7 +323,7 @@ def generate(
     verbose: bool = False,
     formatter: Optional[Callable] = None,
     **kwargs,
-) -> Union[str, Generator[str, None, None]]:
+) -> str:
     """
     Generate a complete response from the model.
@@ -397,7 +349,7 @@ def generate(
     prompt_tokens = mx.array(tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer
 
-    with wired_limit(model):
+    with wired_limit(model, [generation_stream]):
         tic = time.perf_counter()
         detokenizer.reset()
         for n, (token, logprobs) in zip(
@@ -415,8 +367,7 @@ def generate(
             if formatter:
                 # We have to finalize so that the prob corresponds to the last segment
                 detokenizer.finalize()
-                with mx.stream(mx.cpu):
-                    prob = mx.exp(logprobs[token]).item()
+                prob = mx.exp(logprobs[token]).item()
                 formatter(detokenizer.last_segment, prob)
             else:
                 print(detokenizer.last_segment, end="", flush=True)
@@ -437,7 +388,7 @@ def generate(
             f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
         )
         print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-        peak_mem = mx.metal.get_peak_memory() / 2**30
+        peak_mem = mx.metal.get_peak_memory() / 1e9
         print(f"Peak memory: {peak_mem:.3f} GB")
 
     return detokenizer.text
@@ -622,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str):
         f"""
 # {upload_repo}
 
-The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**.
+The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was
+converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path})
+using mlx-lm version **{__version__}**.
 
 ## Use with mlx
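A hedged usage sketch of the reworked `stream_generate`, mirroring the new `(segment, token, logprobs)` yield signature from this diff; the model path and prompt are illustrative:

```python
from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# The prompt may now also be passed as a list of token ids instead of a string.
for segment, token, logprobs in stream_generate(
    model, tokenizer, "Write a haiku about autumn.", max_tokens=64
):
    print(segment, end="", flush=True)
print()
```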
diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py
index 68f1670b..e0a372a9 100644
--- a/llms/tests/test_generate.py
+++ b/llms/tests/test_generate.py
@@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase):
             "hello",
             max_tokens=5,
             verbose=False,
-            logits_processor=[logits_processor],
+            logits_processors=[logits_processor],
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)
 
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index 1e57bd86..0867ab56 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -299,7 +299,7 @@ class TestPromptCache(unittest.TestCase):
         ):
             i += 1
             self.assertEqual(tok, toks[i])
-            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=1e-2))
+            self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
 
 
 if __name__ == "__main__":
diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py
index 3c93fbe2..9c30d51e 100644
--- a/llms/tests/test_tokenizers.py
+++ b/llms/tests/test_tokenizers.py
@@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase):
             text += detokenizer.last_segment
             self.assertEqual(text, expected_text)
 
+        tokens = tokenizer.encode("こんにちは!私の名前はAI")
+        check(tokens)
+
         tokens = tokenizer.encode("a ,b")
         check(tokens)
 
diff --git a/whisper/README.md b/whisper/README.md
index ac6e95f6..cd3bc684 100644
--- a/whisper/README.md
+++ b/whisper/README.md
@@ -25,7 +25,7 @@ pip install mlx-whisper
 
 At its simplest:
 
-```
+```sh
 mlx_whisper audio_file.mp3
 ```
 
@@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There
 are many other supported command line options. To see them all, run
 `mlx_whisper -h`.
 
+You can also pipe the audio content of other programs via stdin:
+
+```sh
+some-process | mlx_whisper -
+```
+
+The default output file name will be `content.*`. You can specify the name with
+the `--output-name` flag.
+
 #### API
 
 Transcribe audio with:
@@ -103,7 +112,7 @@ python convert.py --help
 ```
 
 By default, the conversion script will make the directory `mlx_models`
-and save the converted `weights.npz` and `config.json` there. 
+and save the converted `weights.npz` and `config.json` there.
 
 Each time it is run, `convert.py` will overwrite any model in the provided
 path. To save different models, make sure to set `--mlx-path` to a unique
diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py
index e04309c1..c8cca07c 100644
--- a/whisper/mlx_whisper/audio.py
+++ b/whisper/mlx_whisper/audio.py
@@ -3,7 +3,7 @@ import os
 
 from functools import lru_cache
 from subprocess import CalledProcessError, run
-from typing import Union
+from typing import Optional, Union
 
 import mlx.core as mx
 import numpy as np
@@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH  # 10ms per audio frame
 TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN  # 20ms per audio token
 
 
-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: Optional[str] = None, sr: int = SAMPLE_RATE, from_stdin: bool = False):
     """
     Open an audio file and read as mono waveform, resampling as necessary
 
@@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
     """
 
     # This launches a subprocess to decode audio while down-mixing
-    # and resampling as necessary. Requires the ffmpeg CLI in PATH. 
+    # and resampling as necessary. Requires the ffmpeg CLI in PATH.
+    if from_stdin:
+        cmd = ["ffmpeg", "-i", "pipe:0"]
+    else:
+        cmd = ["ffmpeg", "-nostdin", "-i", file]
+
     # fmt: off
-    cmd = [
-        "ffmpeg",
-        "-nostdin",
+    cmd.extend([
         "-threads", "0",
-        "-i", file,
         "-f", "s16le",
         "-ac", "1",
         "-acodec", "pcm_s16le",
         "-ar", str(sr),
         "-"
-    ]
+    ])
     # fmt: on
     try:
         out = run(cmd, capture_output=True, check=True).stdout
diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py
index c2813338..7d08a043 100644
--- a/whisper/mlx_whisper/cli.py
+++ b/whisper/mlx_whisper/cli.py
@@ -2,9 +2,11 @@
 
 import argparse
 import os
+import pathlib
 import traceback
 import warnings
 
+from . import audio
 from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE
 from .transcribe import transcribe
 from .writers import get_writer
@@ -27,15 +29,24 @@ def build_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
-    parser.add_argument(
-        "audio", nargs="+", type=str, help="Audio file(s) to transcribe"
-    )
+
+    parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe")
+
     parser.add_argument(
         "--model",
         default="mlx-community/whisper-tiny",
         type=str,
         help="The model directory or hugging face repo",
     )
+    parser.add_argument(
+        "--output-name",
+        type=str,
+        default=None,
+        help=(
+            "The name of transcription/translation output files before "
+            "--output-format extensions"
+        ),
+    )
     parser.add_argument(
         "--output-dir",
         "-o",
@@ -200,6 +211,7 @@ def main():
     path_or_hf_repo: str = args.pop("model")
     output_dir: str = args.pop("output_dir")
     output_format: str = args.pop("output_format")
+    output_name: str = args.pop("output_name")
     os.makedirs(output_dir, exist_ok=True)
 
     writer = get_writer(output_format, output_dir)
@@ -219,17 +231,25 @@ def main():
         warnings.warn("--max-line-count has no effect without --max-line-width")
     if writer_args["max_words_per_line"] and writer_args["max_line_width"]:
         warnings.warn("--max-words-per-line has no effect with --max-line-width")
-    for audio_path in args.pop("audio"):
+
+    for audio_obj in args.pop("audio"):
+        if audio_obj == "-":
+            # receive the contents from stdin rather than read a file
+            audio_obj = audio.load_audio(from_stdin=True)
+            name = output_name or "content"
+        else:
+            name = output_name or pathlib.Path(audio_obj).stem
         try:
             result = transcribe(
-                audio_path,
+                audio_obj,
                 path_or_hf_repo=path_or_hf_repo,
                 **args,
             )
-            writer(result, audio_path, **writer_args)
+            writer(result, name, **writer_args)
         except Exception as e:
             traceback.print_exc()
-            print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
+            print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}")
 
 
 if __name__ == "__main__":
diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py
index 464ead18..cdb35063 100644
--- a/whisper/mlx_whisper/writers.py
+++ b/whisper/mlx_whisper/writers.py
@@ -1,10 +1,8 @@
 # Copyright © 2024 Apple Inc.
 
 import json
-import os
+import pathlib
 import re
-import sys
-import zlib
 
 from typing import Callable, List, Optional, TextIO
 
@@ -43,15 +41,13 @@ class ResultWriter:
         self.output_dir = output_dir
 
     def __call__(
-        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
+        self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs
     ):
-        audio_basename = os.path.basename(audio_path)
-        audio_basename = os.path.splitext(audio_basename)[0]
-        output_path = os.path.join(
-            self.output_dir, audio_basename + "." + self.extension
+        output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix(
+            f".{self.extension}"
         )
 
-        with open(output_path, "w", encoding="utf-8") as f:
+        with output_path.open("wt", encoding="utf-8") as f:
             self.write_result(result, file=f, options=options, **kwargs)
 
     def write_result(