diff --git a/flux/README.md b/flux/README.md index 1a17e386..b00a9621 100644 --- a/flux/README.md +++ b/flux/README.md @@ -188,7 +188,7 @@ The adapters are saved in `mlx_output` and can be used directly by the ```shell python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \ - --adapter mlx_output/0001200_adapters.safetensors \ + --adapter mlx_output/final_adapters.safetensors \ --fuse-adapter \ --no-t5-padding \ 'A photo of an sks dog lying on the sand at a beach in Greece' diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 48dcad47..ffdb02d7 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -13,7 +13,7 @@ from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten, tree_map, tree_reduce from PIL import Image -from flux import FluxPipeline, Trainer, load_dataset +from flux import FluxPipeline, Trainer, load_dataset, save_config def generate_progress_images(iteration, flux, args): @@ -43,10 +43,10 @@ def generate_progress_images(iteration, flux, args): im.save(out_file) -def save_adapters(iteration, flux, args): +def save_adapters(adapter_name, flux, args): out_dir = Path(args.output_dir) out_dir.mkdir(parents=True, exist_ok=True) - out_file = out_dir / f"{iteration:07d}_adapters.safetensors" + out_file = out_dir / adapter_name print(f"Saving {str(out_file)}") mx.save_safetensors( @@ -157,6 +157,10 @@ if __name__ == "__main__": parser = setup_arg_parser() args = parser.parse_args() + output_path = Path(args.output_dir) + output_path.mkdir(parents=True, exist_ok=True) + save_config(vars(args), output_path / "adapter_config.json") + # Load the model and set it up for LoRA training. We use the same random # state when creating the LoRA layers so all workers will have the same # initial weights. @@ -278,8 +282,11 @@ if __name__ == "__main__": generate_progress_images(i + 1, flux, args) if (i + 1) % args.checkpoint_every == 0: - save_adapters(i + 1, flux, args) + save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args) if (i + 1) % 10 == 0: losses = [] tic = time.time() + + save_adapters("final_adapters.safetensors", flux, args) + print(f"Training successful. Saved final weights to {args.adapter_file}.") diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py index b1122d75..3dd423b7 100644 --- a/flux/flux/__init__.py +++ b/flux/flux/__init__.py @@ -12,4 +12,5 @@ from .utils import ( load_flow_model, load_t5, load_t5_tokenizer, + save_config, ) diff --git a/flux/flux/utils.py b/flux/flux/utils.py index 21db17d3..2437f21f 100644 --- a/flux/flux/utils.py +++ b/flux/flux/utils.py @@ -3,7 +3,8 @@ import json import os from dataclasses import dataclass -from typing import Optional +from pathlib import Path +from typing import Optional, Union import mlx.core as mx from huggingface_hub import hf_hub_download @@ -207,3 +208,23 @@ def load_clip_tokenizer(name: str): def load_t5_tokenizer(name: str, pad: bool = True): model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model") return T5Tokenizer(model_file, 256 if "schnell" in name else 512) + + +def save_config( + config: dict, + config_path: Union[str, Path], +) -> None: + """Save the model configuration to the ``config_path``. + + The final configuration will be sorted before saving for better readability. + + Args: + config (dict): The model configuration. + config_path (Union[str, Path]): Model configuration file path. 
+ """ + # Sort the config for better readability + config = dict(sorted(config.items())) + + # Write the config to the provided file + with open(config_path, "w") as fid: + json.dump(config, fid, indent=4) diff --git a/llms/README.md b/llms/README.md index 20863041..eeb3ed6a 100644 --- a/llms/README.md +++ b/llms/README.md @@ -101,7 +101,8 @@ To see a description of all the arguments you can do: #### Streaming For streaming generation, use the `stream_generate` function. This returns a -generator object which streams the output text. For example, +generator object which streams the output text, token, and log probabilities. +For example, ```python from mlx_lm import load, stream_generate @@ -116,7 +117,7 @@ prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) -for t in stream_generate(model, tokenizer, prompt, max_tokens=512): +for text, *_ in stream_generate(model, tokenizer, prompt, max_tokens=512): print(t, end="", flush=True) print() ``` @@ -221,6 +222,7 @@ Here are a few examples of Hugging Face models that work with this example: - [pfnet/plamo-13b-instruct](https://huggingface.co/pfnet/plamo-13b-instruct) - [stabilityai/stablelm-2-zephyr-1_6b](https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b) - [internlm/internlm2-7b](https://huggingface.co/internlm/internlm2-7b) +- [tiiuae/falcon-mamba-7b-instruct](https://huggingface.co/tiiuae/falcon-mamba-7b-instruct) Most [Mistral](https://huggingface.co/models?library=transformers,safetensors&other=mistral&sort=trending), @@ -248,3 +250,28 @@ model, tokenizer = load( tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True}, ) ``` + +### Large Models + +> [!NOTE] + This requires macOS 15.0 or higher to work. + +Models which are large relative to the total RAM available on the machine can +be slow. `mlx-lm` will attempt to make them faster by wiring the memory +occupied by the model and cache. This requires macOS 15 or higher to +work. + +If you see the following warning message: + +> [WARNING] Generating with a model that requires ... + +then the model will likely be slow on the given machine. If the model fits in +RAM then it can often be sped up by increasing the system wired memory limit. +To increase the limit, set the following `sysctl`: + +```bash +sudo sysctl iogpu.wired_limit_mb=N +``` + +The value `N` should be larger than the size of the model in megabytes but +smaller than the memory size of the machine. diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 70239db6..3811616f 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. -__version__ = "0.19.1" +__version__ = "0.19.3" diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 04e75a3e..987b640d 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -8,7 +8,9 @@ import time import mlx.core as mx from .models.cache import make_prompt_cache, save_prompt_cache -from .utils import load +from .utils import load, maybe_quantize_kv_cache + +DEFAULT_QUANTIZED_KV_START = 5000 def setup_arg_parser(): @@ -70,6 +72,26 @@ def setup_arg_parser(): required=True, help="Message to be processed by the model ('-' reads from stdin)", ) + parser.add_argument( + "--kv-bits", + type=int, + help="Number of bits for KV cache quantization. 
" + "Defaults to no quantization.", + default=None, + ) + parser.add_argument( + "--kv-group-size", + type=int, + help="Group size for KV cache quantization.", + default=64, + ) + parser.add_argument( + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", + type=int, + default=DEFAULT_QUANTIZED_KV_START, + ) return parser @@ -127,8 +149,10 @@ def main(): start = time.time() max_msg_len = 0 while y.size > 0: + model(y[:step_size][None], cache=cache) mx.eval([c.state for c in cache]) + mx.metal.clear_cache() processed += min(y.size, step_size) y = y[step_size:] current = time.time() @@ -136,15 +160,19 @@ def main(): msg = f"\rProcessed {processed:6d} tokens ({speed:6.2f} tok/s)" max_msg_len = max(max_msg_len, len(msg)) print(msg + " " * (max_msg_len - len(msg)), end="", flush=True) + + maybe_quantize_kv_cache( + cache, args.quantized_kv_start, args.kv_group_size, args.kv_bits + ) + print() - print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") + print(f"Peak memory: {mx.metal.get_peak_memory() / 1e9:.3f} GB") print("Saving...") metadata = {} metadata["model"] = args.model metadata["chat_template"] = tokenizer.chat_template metadata["tokenizer_config"] = json.dumps(tokenizer_config) - print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") save_prompt_cache(args.prompt_cache_file, cache, metadata) diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py index 7968a868..c03056a6 100644 --- a/llms/mlx_lm/chat.py +++ b/llms/mlx_lm/chat.py @@ -11,6 +11,7 @@ from .utils import load, stream_generate DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 +DEFAULT_MAX_TOKENS = 256 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" @@ -41,6 +42,13 @@ def setup_arg_parser(): help="Set the maximum key-value cache size", default=None, ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=DEFAULT_MAX_TOKENS, + help="Maximum number of tokens to generate", + ) return parser @@ -56,7 +64,7 @@ def main(): tokenizer_config={"trust_remote_code": True}, ) - print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.") + print(f"[INFO] Starting chat session with {args.model}. 
To exit, enter 'q'.") prompt_cache = make_prompt_cache(model, args.max_kv_size) while True: query = input(">> ") @@ -66,10 +74,11 @@ def main(): prompt = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) - for response in stream_generate( + for response, *_ in stream_generate( model, tokenizer, prompt, + args.max_tokens, temp=args.temp, top_p=args.top_p, prompt_cache=prompt_cache, diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0bf98ab2..51169def 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -6,15 +6,18 @@ import sys import mlx.core as mx -from .models.cache import load_prompt_cache +from .models.cache import QuantizedKVCache, load_prompt_cache from .utils import generate, load DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 +DEFAULT_MIN_P = 0.0 +DEFAULT_MIN_TOKENS_TO_KEEP = 1 DEFAULT_SEED = 0 DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" +DEFAULT_QUANTIZED_KV_START = 5000 def str2bool(string): @@ -51,6 +54,7 @@ def setup_arg_parser(): ) parser.add_argument( "--prompt", + "-p", default=DEFAULT_PROMPT, help="Message to be processed by the model ('-' reads from stdin)", ) @@ -67,6 +71,15 @@ def setup_arg_parser(): parser.add_argument( "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p" ) + parser.add_argument( + "--min-p", type=float, default=DEFAULT_MIN_P, help="Sampling min-p" + ) + parser.add_argument( + "--min-tokens-to-keep", + type=float, + default=DEFAULT_MIN_TOKENS_TO_KEEP, + help="Minimum tokens to keep for min-p sampling.", + ) parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed") parser.add_argument( "--ignore-chat-template", @@ -89,12 +102,6 @@ def setup_arg_parser(): action="store_true", help="Colorize output based on T[0] probability", ) - parser.add_argument( - "--cache-limit-gb", - type=int, - default=None, - help="Set the MLX cache limit in GB", - ) parser.add_argument( "--max-kv-size", type=int, @@ -107,6 +114,26 @@ def setup_arg_parser(): default=None, help="A file containing saved KV caches to avoid recomputing them", ) + parser.add_argument( + "--kv-bits", + type=int, + help="Number of bits for KV cache quantization. " + "Defaults to no quantization.", + default=None, + ) + parser.add_argument( + "--kv-group-size", + type=int, + help="Group size for KV cache quantization.", + default=64, + ) + parser.add_argument( + "--quantized-kv-start", + help="When --kv-bits is set, start quantizing the KV cache " + "from this step onwards.", + type=int, + default=DEFAULT_QUANTIZED_KV_START, + ) return parser @@ -143,15 +170,22 @@ def main(): mx.random.seed(args.seed) - if args.cache_limit_gb is not None: - mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Load the prompt cache and metadata if a cache file is provided using_cache = args.prompt_cache_file is not None if using_cache: prompt_cache, metadata = load_prompt_cache( - args.prompt_cache_file, return_metadata=True + args.prompt_cache_file, + return_metadata=True, ) + if isinstance(prompt_cache[0], QuantizedKVCache): + if args.kv_bits is not None and args.kv_bits != prompt_cache[0].bits: + raise ValueError( + "--kv-bits does not match the kv cache loaded from --prompt-cache-file." + ) + if args.kv_group_size != prompt_cache[0].group_size: + raise ValueError( + "--kv-group-size does not match the kv cache loaded from --prompt-cache-file." 
+ ) # Building tokenizer_config tokenizer_config = ( @@ -225,8 +259,13 @@ def main(): formatter=formatter, temp=args.temp, top_p=args.top_p, + min_p=args.min_p, + min_tokens_to_keep=args.min_tokens_to_keep, max_kv_size=args.max_kv_size, prompt_cache=prompt_cache if using_cache else None, + kv_bits=args.kv_bits, + kv_group_size=args.kv_group_size, + quantized_kv_start=args.quantized_kv_start, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 3628a808..f02f49b1 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -5,6 +5,9 @@ from dataclasses import dataclass from typing import Any, Optional import mlx.core as mx +from mlx.utils import tree_map + +from .cache import QuantizedKVCache @dataclass @@ -39,7 +42,7 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): if cache is not None and cache[0] is not None: c = cache[0] if hasattr(c, "max_size"): - offset = min(c.max_size - 1, c.offset) + offset = min(c.max_size, c.offset) window_size = c.max_size else: offset = c.offset @@ -48,3 +51,63 @@ def create_attention_mask(h: mx.array, cache: Optional[Any] = None): else: mask = None return mask + + +def quantized_scaled_dot_product_attention( + queries: mx.array, + q_keys: tuple[mx.array, mx.array, mx.array], + q_values: tuple[mx.array, mx.array, mx.array], + scale: float, + mask: Optional[mx.array], + group_size: int = 64, + bits: int = 8, +) -> mx.array: + B, n_q_heads, L, D = queries.shape + n_kv_heads = q_keys[0].shape[-3] + n_repeats = n_q_heads // n_kv_heads + + queries *= scale + + if n_repeats > 1: + queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) + q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) + q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) + + scores = mx.quantized_matmul( + queries, *q_keys, transpose=True, group_size=group_size, bits=bits + ) + if mask is not None: + scores += mask + scores = mx.softmax(scores, axis=-1, precise=True) + out = mx.quantized_matmul( + scores, *q_values, transpose=False, group_size=group_size, bits=bits + ) + + if n_repeats > 1: + out = mx.reshape(out, (B, n_q_heads, L, D)) + + return out + + +def scaled_dot_product_attention( + queries, + keys, + values, + cache, + scale: float, + mask: Optional[mx.array], +) -> mx.array: + if isinstance(cache, QuantizedKVCache): + return quantized_scaled_dot_product_attention( + queries, + keys, + values, + scale=scale, + mask=mask, + group_size=cache.group_size, + bits=cache.bits, + ) + else: + return mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=scale, mask=mask + ) diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py index 84fadd06..df6784b5 100644 --- a/llms/mlx_lm/models/cache.py +++ b/llms/mlx_lm/models/cache.py @@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional import mlx.core as mx import mlx.nn as nn -from mlx.utils import tree_flatten, tree_unflatten +from mlx.utils import tree_flatten, tree_map, tree_unflatten -def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]: +def make_prompt_cache( + model: nn.Module, + max_kv_size: Optional[int] = None, +) -> List[Any]: """ Construct the model's cache for use when cgeneration. 
@@ -126,6 +129,88 @@ class _BaseCache: return False +class QuantizedKVCache(_BaseCache): + def __init__(self, group_size: int = 64, bits: int = 8): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + self.group_size = group_size + self.bits = bits + + def update_and_fetch(self, keys, values): + B, n_kv_heads, num_steps, k_head_dim = keys.shape + v_head_dim = values.shape[-1] + prev = self.offset + + if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]: + el_per_int = 8 * mx.uint32.size // self.bits + new_steps = (self.step + num_steps - 1) // self.step * self.step + shape = (B, n_kv_heads, new_steps) + + def init_quant(dim): + return ( + mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + ) + + def expand_quant(x): + new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype) + return mx.concatenate([x, new_x], axis=-2) + + if self.keys is not None: + if prev % self.step != 0: + self.keys, self.values = tree_map( + lambda x: x[..., :prev, :], (self.keys, self.values) + ) + + self.keys, self.values = tree_map( + expand_quant, (self.keys, self.values) + ) + else: + self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim) + + self.offset += num_steps + + keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits) + values = mx.quantize(values, group_size=self.group_size, bits=self.bits) + for i in range(len(self.keys)): + self.keys[i][..., prev : self.offset, :] = keys[i] + self.values[i][..., prev : self.offset, :] = values[i] + + return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values)) + + @property + def state(self): + if self.offset == self.keys[0].shape[2]: + return self.keys, self.values + else: + return tree_map( + lambda x: x[..., : self.offset, :], (self.keys, self.values) + ) + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple(map(str, (self.step, self.offset, self.group_size, self.bits))) + + @meta_state.setter + def meta_state(self, v): + self.step, self.offset, self.group_size, self.bits = map(int, v) + + def is_trimmable(self): + return True + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + return n + + class KVCache(_BaseCache): def __init__(self): self.keys = None @@ -180,6 +265,16 @@ class KVCache(_BaseCache): self.offset -= n return n + def to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + quant_cache = QuantizedKVCache(group_size=group_size, bits=bits) + quant_cache.offset = self.offset + if self.keys is not None: + quant_cache.keys = mx.quantize(self.keys, group_size=group_size, bits=bits) + quant_cache.values = mx.quantize( + self.values, group_size=group_size, bits=bits + ) + return quant_cache + class RotatingKVCache(_BaseCache): @@ -230,9 +325,9 @@ class RotatingKVCache(_BaseCache): self.keys = self._temporal_order(self.keys) self.values = self._temporal_order(self.values) - # The largest size is self.max_size + S - 1 to ensure + # The largest size is self.max_size + S to ensure # every token gets at least self.max_size context - trim_size = self._idx - self.max_size + 1 + trim_size = self._idx - self.max_size self.keys = self._trim(trim_size, self.keys, keys) self.values = self._trim(trim_size, self.values, values) self.offset += keys.shape[2] @@ -320,6 +415,9 @@ class RotatingKVCache(_BaseCache): self._idx -= n return n + def 
to_quantized(self, group_size: int = 64, bits: int = 4) -> QuantizedKVCache: + raise NotImplementedError("RotatingKVCache Quantization NYI") + class MambaCache: def __init__(self): diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 057c816d..7e002b0c 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,8 +93,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index 3b7e83d7..7be274cc 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.out_proj(output) diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index 03cb3b1a..b7b24dba 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -4,7 +4,7 @@ from typing import Any, Dict, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -97,8 +97,8 @@ class DeepseekAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index bb3e5184..444813b9 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -235,8 +235,8 @@ class DeepseekV2Attention(nn.Module): queries = mx.concatenate([q_nope, q_pe], axis=-1) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/gemma.py 
b/llms/mlx_lm/models/gemma.py index 61de781e..3f384c3f 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -6,7 +6,7 @@ from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 97d9a8ff..52076a34 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -61,8 +61,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 068046ea..23e86e20 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -74,8 +74,8 @@ class Attention(nn.Module): if cache is not None: keys, values = cache.update_and_fetch(keys, values) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.c_proj(output) diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index 9f662491..ccb0b28b 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention # Based on the transformers implementation at: # https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index 5264cb57..f5ce057e 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Tuple, Union 
import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -141,8 +141,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.wo(output) diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index 7da6b333..438278e5 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -190,9 +190,10 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 84f498e9..f2414660 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -23,6 +23,8 @@ class ModelArgs(BaseModelArgs): use_conv_bias: bool time_step_rank: int tie_word_embeddings: bool = True + use_bcdt_rms: bool = False + mixer_rms_eps: float = 1e-6 def __post_init__(self): if not hasattr(self, "hidden_size") and hasattr(self, "d_model"): @@ -44,6 +46,8 @@ class ModelArgs(BaseModelArgs): if self.time_step_rank == "auto": self.time_step_rank = math.ceil(self.hidden_size / 16) + if self.model_type == "falcon_mamba": + self.use_bcdt_rms = True class DepthWiseConv1d(nn.Module): @@ -83,6 +87,11 @@ class MambaBlock(nn.Module): self.intermediate_size = args.intermediate_size self.time_step_rank = int(args.time_step_rank) self.use_conv_bias = args.use_conv_bias + self.use_bcdt_rms = args.use_bcdt_rms + if self.use_bcdt_rms: + self.mixer_norm = lambda x: mx.fast.rms_norm( + x, mx.ones(x.shape[-1], x.dtype), eps=args.mixer_rms_eps + ) self.in_proj = nn.Linear( self.hidden_size, self.intermediate_size * 2, bias=args.use_bias @@ -126,6 +135,8 @@ class MambaBlock(nn.Module): ], axis=-1, ) + if self.use_bcdt_rms: + delta, B, C = map(self.mixer_norm, (delta, B, C)) delta = nn.softplus(self.dt_proj(delta)) new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1) if state is not None: diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index 4ac3c3b4..907beb2a 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -105,8 +105,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - attn_output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + 
attn_output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 20944fe3..dd94d1f4 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -87,8 +87,8 @@ class MixtralAttention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index 3ea06e27..f73c0277 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -113,8 +113,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 090e21c6..408802f4 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index 56b383b2..510025ea 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -7,7 +7,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -93,8 +93,13 @@ class PhiAttention(nn.Module): keys = self.rope(keys) scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) diff --git 
a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 9ef76f04..ee6efc49 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding @@ -107,8 +107,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 6b0759b4..53e1a638 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -8,7 +8,7 @@ from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -188,8 +188,8 @@ class Attention(nn.Module): queries, keys, values, scale=self.scale, mask=mask ) else: - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.dense(output) diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index ca20a388..f42a6dd0 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .su_rope import SuScaledRotaryEmbedding from .switch_layers import SwitchGLU @@ -79,8 +79,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index 865d0d8e..42d647b0 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -8,7 +8,7 @@ from typing import Tuple import mlx.core as mx import mlx.nn as nn -from .base import create_attention_mask +from .base import create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchMLP @@ -71,8 +71,13 @@ class RoPEAttention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries.astype(mx.float32), keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries.astype(mx.float32), + keys, + values, + cache=cache, + scale=scale, + mask=mask, ).astype(values.dtype) output = output.moveaxis(2, 1).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index b0fd1a6c..c8e5bf50 100644 
--- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -7,7 +7,7 @@ import mlx.core as mx import mlx.nn as nn import numpy as np -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -92,10 +92,11 @@ class Attention(nn.Module): keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1]) values = mx.tile(values, [1, self.config.n_shared_head, 1, 1]) - output = mx.fast.scaled_dot_product_attention( + output = scaled_dot_product_attention( queries, keys, values, + cache=cache, scale=self.scale, mask=attention_mask, ) diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 2b69d5ec..8145a890 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -5,7 +5,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rotary_emb(queries) keys = self.rotary_emb(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index 4e7858de..fac59d78 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -6,7 +6,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -89,8 +89,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index d199116f..167fc5dd 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -7,7 +7,7 @@ from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention from .switch_layers import SwitchGLU @@ -89,8 +89,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 5595d311..49e4bb8f 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -7,7 +7,7 @@ from typing import List, Literal, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, 
scaled_dot_product_attention from .cache import MambaCache, RotatingKVCache @@ -263,8 +263,8 @@ class LocalAttentionBlock(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index 11202b02..482bb324 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -6,7 +6,7 @@ from dataclasses import dataclass import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -120,8 +120,8 @@ class Attention(nn.Module): # Finally perform the attention computation scale = math.sqrt(1 / queries.shape[-1]) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=scale, mask=mask ).astype(values.dtype) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) return self.o_proj(output) diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index ce0a2ec5..d7e626f2 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -6,7 +6,7 @@ from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, create_attention_mask +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention @dataclass @@ -64,8 +64,8 @@ class Attention(nn.Module): queries = self.rope(queries) keys = self.rope(keys) - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt index 814c03cc..48012863 100644 --- a/llms/mlx_lm/requirements.txt +++ b/llms/mlx_lm/requirements.txt @@ -1,4 +1,4 @@ -mlx>=0.17.0 +mlx>=0.19.2 numpy transformers[sentencepiece]>=4.39.3 protobuf diff --git a/llms/mlx_lm/sample_utils.py b/llms/mlx_lm/sample_utils.py index 20b008fa..c27b52d8 100644 --- a/llms/mlx_lm/sample_utils.py +++ b/llms/mlx_lm/sample_utils.py @@ -1,10 +1,83 @@ # Copyright © 2023-2024 Apple Inc. from functools import partial +from typing import Callable, Dict, Optional import mlx.core as mx +def make_sampler( + temp: float = 0.0, + top_p: float = 0.0, + min_p: float = 0.0, + min_tokens_to_keep: int = 1, +) -> Callable[mx.array, mx.array]: + """ + Make a sampler function for use with ``generate_step``. + + Args: + temp (float): The temperature for sampling, if 0 the argmax is used. + Default: ``0``. + top_p (float, optional): Nulceus sampling, higher means model considers + more less likely words. + min_p (float, optional): The minimum value (scaled by the top token's + probability) that a token probability must have to be considered. + min_tokens_to_keep (int, optional): Minimum number of tokens that cannot + be filtered by min_p sampling. + + Returns: + Callable[mx.array, mx.array]: + A sampler which takes log-probabilities and returns tokens. 
+ """ + if temp == 0: + return lambda x: mx.argmax(x, axis=-1) + elif top_p > 0 and top_p < 1.0: + return lambda x: top_p_sampling(x, top_p, temp) + elif min_p != 0.0: + return lambda x: min_p_sampling(x, min_p, min_tokens_to_keep, temp) + else: + return lambda x: categorical_sampling(x, temp) + + +def make_logits_processors( + logit_bias: Optional[Dict[int, float]] = None, + repetition_penalty: Optional[float] = None, + repetition_context_size: Optional[int] = 20, +): + """ + Make logits processors for use with ``generate_step``. + + Args: + repetition_penalty (float, optional): The penalty factor for repeating + tokens. + repetition_context_size (int, optional): The number of tokens to + consider for repetition penalty. Default: ``20``. + logit_bias (dictionary, optional): Additive logit bias. + + Returns: + List[Callable[[mx.array, mx.array], mx.array]]: + A list of logits processors. Each processor in the list is a + callable which takes an array of tokens and an array of logits + and returns the updated logits. + """ + logits_processors = [] + if logit_bias: + indices = mx.array(list(logit_bias.keys())) + values = mx.array(list(logit_bias.values())) + + def logit_bias_processor(_, logits): + logits[:, indices] += values + return logits + + logits_processors.append(logit_bias_processor) + + if repetition_penalty and repetition_penalty != 0.0: + logits_processors.append( + make_repetition_penalty(repetition_penalty, repetition_context_size) + ) + return logits_processors + + @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def min_p_sampling( logits: mx.array, @@ -100,3 +173,36 @@ def top_p_sampling(logits: mx.array, top_p: float, temperature: float) -> mx.arr @partial(mx.compile, inputs=mx.random.state, outputs=mx.random.state) def categorical_sampling(logits, temp): return mx.random.categorical(logits * (1 / temp)) + + +def make_repetition_penalty(penalty: float, context_size: int = 20): + """ + Make repetition penalty processor. + + Paper: https://arxiv.org/abs/1909.05858 + + Args: + penalty (float): The repetition penalty factor to be applied. + context_size (int): The number of previous tokens to use. + Default: ``20``. + + Returns: + Callable[[mx.array, List[int]], mx.array]: + The repetition penalty processor. + """ + if penalty < 0 or not isinstance(penalty, float): + raise ValueError(f"penalty must be a non-negative float, got {penalty}") + + def repetition_penalty_processor(tokens, logits): + if len(tokens) > 0: + tokens = tokens[-context_size:] + selected_logits = logits[:, tokens] + selected_logits = mx.where( + selected_logits < 0, + selected_logits * penalty, + selected_logits / penalty, + ) + logits[:, tokens] = selected_logits + return logits + + return repetition_penalty_processor diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index ec659969..c1365b36 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -27,7 +27,7 @@ from huggingface_hub import scan_cache_dir from ._version import __version__ from .models.cache import make_prompt_cache -from .utils import generate_step, load +from .utils import load, stream_generate def get_system_fingerprint(): @@ -64,7 +64,7 @@ def stopping_criteria( end if it has (`trim_length`). 
""" if tokens and tokens[-1] == eos_token_id: - return StopCondition(stop_met=True, trim_length=1) + return StopCondition(stop_met=True, trim_length=0) for stop_ids in stop_id_sequences: if len(tokens) >= len(stop_ids): @@ -253,7 +253,7 @@ class APIHandler(BaseHTTPRequestHandler): self.max_tokens = self.body.get("max_completion_tokens", None) if self.max_tokens is None: self.max_tokens = self.body.get("max_tokens", 512) - self.temperature = self.body.get("temperature", 1.0) + self.temperature = self.body.get("temperature", 0.0) self.top_p = self.body.get("top_p", 1.0) self.repetition_penalty = self.body.get("repetition_penalty", 1.0) self.repetition_context_size = self.body.get("repetition_context_size", 20) @@ -290,10 +290,7 @@ class APIHandler(BaseHTTPRequestHandler): # Call endpoint specific method prompt = endpoints[self.path]() - - # Call method based on response type - method = self.handle_stream if self.stream else self.handle_completion - method(prompt, stop_id_sequences) + self.handle_completion(prompt, stop_id_sequences) def validate_model_parameters(self): """ @@ -452,32 +449,40 @@ class APIHandler(BaseHTTPRequestHandler): stop_id_sequences (List[List[int]]): A list of stop words passed to the stopping_criteria function """ - detokenizer = self.tokenizer.detokenizer - detokenizer.reset() tokens = [] finish_reason = "length" stop_sequence_suffix = None - logging.debug(f"Starting completion:") + if self.stream: + self.end_headers() + logging.debug(f"Starting stream:") + else: + logging.debug(f"Starting completion:") token_logprobs = [] top_tokens = [] prompt = self.get_prompt_cache(prompt) - for _, (token, logprobs) in zip( - range(self.max_tokens), - generate_step( - prompt=mx.array(prompt), + text = "" + tic = time.perf_counter() + for n, (segment, token, logprobs) in enumerate( + stream_generate( model=self.model, + tokenizer=self.tokenizer, + prompt=prompt, + max_tokens=self.max_tokens, temp=self.temperature, - top_p=self.top_p, repetition_penalty=self.repetition_penalty, repetition_context_size=self.repetition_context_size, logit_bias=self.logit_bias, prompt_cache=self.prompt_cache.cache, ), ): - detokenizer.add_token(token) - logging.debug(detokenizer.text) + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + + text += segment + logging.debug(text) tokens.append(token) if self.logprobs > 0: @@ -498,121 +503,63 @@ class APIHandler(BaseHTTPRequestHandler): stop_sequence_suffix = self.tokenizer.decode( tokens[-stop_condition.trim_length :] ) + text = text[: -len(stop_sequence_suffix)] break - self.prompt_cache.tokens.extend(tokens) - detokenizer.finalize() - text = ( - detokenizer.text - if stop_sequence_suffix is None - else detokenizer.text[: -len(stop_sequence_suffix)] - ) - response = self.generate_response( - text, - finish_reason, - len(prompt), - len(tokens), - token_logprobs=token_logprobs, - top_tokens=top_tokens, - tokens=tokens, - ) - - response_json = json.dumps(response).encode() - indent = "\t" # Backslashes can't be inside of f-strings - logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - - # Send an additional Content-Length header when it is known - self.send_header("Content-Length", str(len(response_json))) - self.end_headers() - - self.wfile.write(response_json) - self.wfile.flush() - - def handle_stream( - self, - prompt: List[int], - stop_id_sequences: List[List[int]], - ): - """ - Generate response to prompt and foward it to the client using a Server - Sent Events (SSE) stream. 
- - Args: - prompt (mx.array): The tokenized prompt - stop_id_sequences (List[List[int]]): A list of stop words passed to - the stopping_criteria function - """ - # No additional headers are needed, call end_headers - self.end_headers() - - detokenizer = self.tokenizer.detokenizer - detokenizer.reset() - tokens = [] - - stop_sequence_suffix = None - logging.debug(f"Starting stream:") - - prompt = self.get_prompt_cache(prompt) - - for _, (token, _) in zip( - range(self.max_tokens), - generate_step( - prompt=mx.array(prompt), - model=self.model, - temp=self.temperature, - top_p=self.top_p, - repetition_penalty=self.repetition_penalty, - repetition_context_size=self.repetition_context_size, - prompt_cache=self.prompt_cache.cache, - ), - ): - detokenizer.add_token(token) - logging.debug(detokenizer.text) - tokens.append(token) - - stop_condition = stopping_criteria( - tokens, - stop_id_sequences, - self.tokenizer.eos_token_id, - ) - if stop_condition.stop_met: - if stop_condition.trim_length: - stop_sequence_suffix = self.tokenizer.decode( - tokens[-stop_condition.trim_length :] + if self.stream: + # If the end of tokens overlaps with a stop sequence, generate new + # tokens until we know if the stop sequence is hit or not + if any( + ( + sequence_overlap(tokens, sequence) + for sequence in stop_id_sequences ) - break - - # If the end of tokens overlaps with a stop sequence, generate new - # tokens until we know if the stop sequence is hit or not - if any( - (sequence_overlap(tokens, sequence) for sequence in stop_id_sequences) - ): - continue - - new_text = detokenizer.last_segment - if new_text: - response = self.generate_response(new_text, None) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - self.wfile.flush() + ): + continue + elif segment: + response = self.generate_response(segment, None) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() self.prompt_cache.tokens.extend(tokens) - # check is there any remaining text to send - detokenizer.finalize() - last_segment = detokenizer.last_segment - if last_segment: - if stop_sequence_suffix is not None: - last_segment = last_segment[: -len(stop_sequence_suffix)] - response = self.generate_response(last_segment, "length") + gen_time = time.perf_counter() - tic + prompt_tps = len(prompt) / prompt_time + gen_tps = len(tokens) / gen_time + peak_mem = mx.metal.get_peak_memory() / 1e9 + logging.debug(f"Prompt: {prompt_tps:.3f} tokens-per-sec") + logging.debug(f"Generation: {gen_tps:.3f} tokens-per-sec") + logging.debug(f"Peak memory: {peak_mem:.3f} GB") + + if self.stream: + response = self.generate_response(segment, finish_reason) self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) self.wfile.flush() + if self.stream_options is not None and self.stream_options["include_usage"]: + response = self.completion_usage_response(len(prompt), len(tokens)) + self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) + self.wfile.flush() + self.wfile.write("data: [DONE]\n\n".encode()) + self.wfile.flush() + else: + response = self.generate_response( + text, + finish_reason, + len(prompt), + len(tokens), + token_logprobs=token_logprobs, + top_tokens=top_tokens, + tokens=tokens, + ) + response_json = json.dumps(response).encode() + indent = "\t" # Backslashes can't be inside of f-strings + logging.debug(f"Outgoing Response: {json.dumps(response, indent=indent)}") - if self.stream_options is not None and self.stream_options["include_usage"]: - response = 
self.completion_usage_response(len(prompt), len(tokens)) - self.wfile.write(f"data: {json.dumps(response)}\n\n".encode()) - - self.wfile.write("data: [DONE]\n\n".encode()) - self.wfile.flush() + # Send an additional Content-Length header when it is known + self.send_header("Content-Length", str(len(response_json))) + self.end_headers() + self.wfile.write(response_json) + self.wfile.flush() def completion_usage_response( self, diff --git a/llms/mlx_lm/tokenizer_utils.py b/llms/mlx_lm/tokenizer_utils.py index 0cbc3b9b..9d390733 100644 --- a/llms/mlx_lm/tokenizer_utils.py +++ b/llms/mlx_lm/tokenizer_utils.py @@ -6,12 +6,6 @@ from transformers import AutoTokenizer REPLACEMENT_CHAR = "\ufffd" -def _remove_space(x): - if x and x[0] == " ": - return x[1:] - return x - - class StreamingDetokenizer: """The streaming detokenizer interface so that we can detokenize one token at a time. @@ -123,42 +117,42 @@ class SPMStreamingDetokenizer(StreamingDetokenizer): def __init__(self, tokenizer, trim_space=True): self.trim_space = trim_space + self._sep = "\u2581".encode() # Extract the tokens in a list from id to text self.tokenmap = [""] * (max(tokenizer.vocab.values()) + 1) for value, tokenid in tokenizer.vocab.items(): - self.tokenmap[tokenid] = value - - # Replace bytes with their value - for i in range(len(self.tokenmap)): - if self.tokenmap[i].startswith("<0x"): - self.tokenmap[i] = chr(int(self.tokenmap[i][3:5], 16)) + if value.startswith("<0x"): + # Replace bytes with their value + self.tokenmap[tokenid] = bytes([int(value[3:5], 16)]) + else: + self.tokenmap[tokenid] = value.encode() self.reset() def reset(self): self.offset = 0 - self._unflushed = "" + self._unflushed = b"" self.text = "" self.tokens = [] + def _flush(self): + text = self._unflushed.replace(self._sep, b" ").decode("utf-8") + if not self.text and self.trim_space and text and text[0] == " ": + text = text[1:] + self.text += text + def add_token(self, token): v = self.tokenmap[token] - if v[0] == "\u2581": - if self.text or not self.trim_space: - self.text += self._unflushed.replace("\u2581", " ") - else: - self.text = _remove_space(self._unflushed.replace("\u2581", " ")) + if v.startswith(self._sep): + self._flush() self._unflushed = v else: self._unflushed += v def finalize(self): - if self.text or not self.trim_space: - self.text += self._unflushed.replace("\u2581", " ") - else: - self.text = _remove_space(self._unflushed.replace("\u2581", " ")) - self._unflushed = "" + self._flush() + self._unflushed = b"" class BPEStreamingDetokenizer(StreamingDetokenizer): @@ -186,6 +180,8 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): # https://github.com/openai/gpt-2/blob/master/src/encoder.py self.make_byte_decoder() + self._added_ids = set(tokenizer.added_tokens_decoder.keys()) + def reset(self): self.offset = 0 self._unflushed = "" @@ -205,12 +201,17 @@ class BPEStreamingDetokenizer(StreamingDetokenizer): def add_token(self, token): v = self.tokenmap[token] - if self._byte_decoder[v[0]] == 32: + is_added = token in self._added_ids + if is_added or self._byte_decoder[v[0]] == 32: current_text = bytearray( self._byte_decoder[c] for c in self._unflushed ).decode("utf-8") self.text += self._maybe_trim_space(current_text) - self._unflushed = v + if is_added: + self.text += v + self._unflushed = "" + else: + self._unflushed = v else: self._unflushed += v diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 1d934a72..21b1af18 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -10,6 
+10,7 @@ from typing import Union import mlx.core as mx import mlx.nn as nn import numpy as np +from mlx.nn.utils import average_gradients from mlx.utils import tree_flatten @@ -84,9 +85,16 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) f" examples but only has {len(dataset)}." ) + # If running in distributed mode (N machines) then each one should skip N-1 + # samples + step = mx.distributed.init().size() + if batch_size % step != 0: + raise ValueError("The batch size must be divisible by the number of workers") + # Make the batches: batch_idx = [ - idx[i : i + batch_size] for i in range(0, len(idx) - batch_size + 1, batch_size) + idx[i : i + batch_size : step] + for i in range(0, len(idx) - batch_size + 1, batch_size) ] while True: @@ -112,9 +120,9 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) max_length_in_batch = min(max_length_in_batch, max_seq_length) - batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32) + batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) - for j in range(batch_size): + for j in range(batch_size // step): truncated_length = min(lengths[j], max_seq_length) batch_arr[j, :truncated_length] = batch[j][:truncated_length] lengths[j] = ( @@ -138,7 +146,7 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = [] + all_losses = 0 ntokens = 0 index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) @@ -153,10 +161,14 @@ def evaluate( ), ): losses, toks = loss(model, *batch) - all_losses.append((losses * toks).item()) - ntokens += toks.item() + all_losses += losses * toks + ntokens += toks + mx.eval(all_losses, ntokens) - return np.sum(all_losses) / ntokens + all_losses = mx.distributed.all_sum(all_losses) + ntokens = mx.distributed.all_sum(ntokens) + + return (all_losses / ntokens).item() class TrainingCallback: @@ -182,6 +194,11 @@ def train( training_callback: TrainingCallback = None, ): print(f"Starting training..., iters: {args.iters}") + world = mx.distributed.init() + world_size = world.size() + rank = world.rank() + if world_size > 1: + print(f"Node {rank} of {world_size}") if args.grad_checkpoint: grad_checkpoint(model.layers[0]) @@ -192,6 +209,9 @@ def train( # Forward and backward pass (lvalue, toks), grad = loss_value_and_grad(model, *batch) + # All reduce the gradients if running in distributed mode + grad = average_gradients(grad) + # Model update optimizer.update(model, grad) @@ -199,8 +219,9 @@ def train( loss_value_and_grad = nn.value_and_grad(model, loss) - losses = [] + losses = 0 n_tokens = 0 + steps = 0 trained_tokens = 0 # Main training loop start = time.perf_counter() @@ -229,9 +250,13 @@ def train( iterate_batches=iterate_batches, ) val_time = time.perf_counter() - stop - print( - f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s" - ) + if rank == 0: + print( + f"Iter {it}: " + f"Val loss {val_loss:.3f}, " + f"Val took {val_time:.3f}s", + flush=True, + ) if training_callback is not None: val_info = { @@ -244,30 +269,33 @@ def train( start = time.perf_counter() lvalue, toks = step(batch) - mx.eval(state, lvalue, toks) - - # Record loss - losses.append(lvalue.item()) - n_tokens += toks.item() + losses += lvalue + n_tokens += toks + steps += 1 + mx.eval(state, losses, n_tokens) # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: stop = 
time.perf_counter()
-            train_loss = np.mean(losses)
+            train_loss = mx.distributed.all_sum(losses).item()
+            train_loss /= steps * mx.distributed.init().size()
+            n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
-            peak_mem = mx.metal.get_peak_memory() / 2**30
-            print(
-                f"Iter {it}: Train loss {train_loss:.3f}, "
-                f"Learning Rate {learning_rate:.3e}, "
-                f"It/sec {it_sec:.3f}, "
-                f"Tokens/sec {tokens_sec:.3f}, "
-                f"Trained Tokens {trained_tokens}, "
-                f"Peak mem {peak_mem:.3f} GB"
-            )
+            peak_mem = mx.metal.get_peak_memory() / 1e9
+            if rank == 0:
+                print(
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
+                    f"Learning Rate {learning_rate:.3e}, "
+                    f"It/sec {it_sec:.3f}, "
+                    f"Tokens/sec {tokens_sec:.3f}, "
+                    f"Trained Tokens {trained_tokens}, "
+                    f"Peak mem {peak_mem:.3f} GB",
+                    flush=True,
+                )
 
             if training_callback is not None:
                 train_info = {
@@ -281,8 +309,9 @@
                 }
                 training_callback.on_train_loss_report(train_info)
 
-            losses = []
+            losses = 0
             n_tokens = 0
+            steps = 0
             start = time.perf_counter()
 
         # Save adapter weights
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 92741b68..8893b570 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.
 
+import contextlib
 import copy
 import glob
 import importlib
@@ -14,12 +15,12 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
-from .models import base, cache
-from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
+from .models import cache
+from .sample_utils import make_logits_processors, make_sampler
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import dequantize as dequantize_model
 from .tuner.utils import load_adapters
@@ -28,10 +29,14 @@ from .tuner.utils import load_adapters
 MODEL_REMAPPING = {
     "mistral": "llama",  # mistral is compatible with llama
     "phi-msft": "phixtral",
+    "falcon_mamba": "mamba",
 }
 
 MAX_FILE_SIZE_GB = 5
 
+# A stream on the default device just for generation
+generation_stream = mx.new_stream(mx.default_device())
+
 
 class ModelNotFoundError(Exception):
     def __init__(self, message):
@@ -39,6 +44,40 @@ class ModelNotFoundError(Exception):
         super().__init__(self.message)
 
 
+@contextlib.contextmanager
+def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
+    """
+    A context manager to temporarily change the wired limit.
+
+    Note, the wired limit should not be changed during an async eval. If an
+    async eval could be running, pass in the streams to synchronize with prior
+    to exiting the context manager.
+    """
+    model_bytes = tree_reduce(
+        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+    )
+    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
+    if model_bytes > 0.9 * max_rec_size:
+        model_mb = model_bytes // 2**20
+        max_rec_mb = max_rec_size // 2**20
+        print(
+            f"[WARNING] Generating with a model that requires {model_mb} MB "
+            f"which is close to the maximum recommended size of {max_rec_mb} "
+            "MB. This can be slow. See the documentation for possible work-arounds: "
+            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
+        )
+    old_limit = mx.metal.set_wired_limit(max_rec_size)
+    try:
+        yield None
+    finally:
+        if streams is not None:
+            for s in streams:
+                mx.synchronize(s)
+        else:
+            mx.synchronize()
+        mx.metal.set_wired_limit(old_limit)
+
+
 def _get_classes(config: dict):
     """
     Retrieve the model and model args classes based on the configuration.
@@ -101,27 +140,16 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path
     return model_path
 
 
-def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float):
-    """
-    Apply repetition penalty to specific logits based on the given context.
-
-    Paper: https://arxiv.org/abs/1909.05858
-
-    Args:
-        logits (mx.array): The logits produced by the language model.
-        tokens (mx.array): A list of N previous tokens.
-        penalty (float): The repetition penalty factor to be applied.
-
-    Returns:
-        logits (mx.array): Logits with repetition penalty applied to generated tokens.
-    """
-    if len(tokens) > 0:
-        selected_logits = logits[:, tokens]
-        selected_logits = mx.where(
-            selected_logits < 0, selected_logits * penalty, selected_logits / penalty
-        )
-        logits[:, tokens] = selected_logits
-    return logits
+def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_bits):
+    if (
+        kv_bits is not None
+        and not isinstance(prompt_cache[0], cache.QuantizedKVCache)
+        and prompt_cache[0].offset > quantized_kv_start
+    ):
+        for i in range(len(prompt_cache)):
+            prompt_cache[i] = prompt_cache[i].to_quantized(
+                group_size=kv_group_size, bits=kv_bits
+            )
 
 
 def generate_step(
@@ -137,7 +165,10 @@
     max_kv_size: Optional[int] = None,
     prompt_cache: Optional[Any] = None,
     logit_bias: Optional[Dict[int, float]] = None,
-    logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    kv_bits: Optional[int] = None,
+    kv_group_size: int = 64,
+    quantized_kv_start: int = 0,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -163,80 +194,56 @@
         prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
           provided, the cache will be updated in place.
         logit_bias (dictionary, optional): Additive logit bias.
-        logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
+        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
           A list of functions that take tokens and logits and return the processed
           logits. Default: ``None``.
+        kv_bits (int, optional): Number of bits to use for KV cache quantization.
+          None implies no cache quantization. Default: ``None``.
+        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
+        quantized_kv_start (int): Step to begin using a quantized KV cache
+          when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
-        Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
-        one token and a vector of log probabilities.
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
""" - def sample(logits: mx.array) -> Tuple[mx.array, float]: - logprobs = logits - mx.logsumexp(logits) - - if temp == 0: - token = mx.argmax(logits, axis=-1) - else: - if top_p > 0 and top_p < 1.0: - token = top_p_sampling(logits, top_p, temp) - elif min_p != 0.0: - token = min_p_sampling(logits, min_p, min_tokens_to_keep, temp) - else: - token = categorical_sampling(logits, temp) - - return token, logprobs - - if repetition_penalty and ( - repetition_penalty < 0 or not isinstance(repetition_penalty, float) - ): - raise ValueError( - f"repetition_penalty must be a non-negative float, got {repetition_penalty}" - ) - - logits_processor = logits_processor or [] - - if repetition_penalty: - - def repetition_penalty_processor(tokens, logits): - return apply_repetition_penalty( - logits, tokens[-repetition_context_size:], repetition_penalty - ) - - logits_processor.append(repetition_penalty_processor) - - if logit_bias: - indices = mx.array(list(logit_bias.keys())) - values = mx.array(list(logit_bias.values())) - - def logit_bias_processor(_, logits): - logits[:, indices] += values - return logits - - logits_processor.append(logit_bias_processor) - y = prompt tokens = None # Create the KV cache for generation if prompt_cache is None: - prompt_cache = cache.make_prompt_cache(model, max_kv_size) + prompt_cache = cache.make_prompt_cache( + model, + max_kv_size=max_kv_size, + ) elif len(prompt_cache) != len(model.layers): raise ValueError("Wrong number of layers in the prompt cache.") + sampler = make_sampler(temp, top_p, min_p, min_tokens_to_keep) + logits_processors = logits_processors or [] + logits_processors.extend( + make_logits_processors(logit_bias, repetition_penalty, repetition_context_size) + ) + def _step(y): - logits = model(y[None], cache=prompt_cache) - logits = logits[:, -1, :] + with mx.stream(generation_stream): + logits = model(y[None], cache=prompt_cache) + logits = logits[:, -1, :] - if logits_processor: - nonlocal tokens - tokens = mx.concat([tokens, y]) if tokens is not None else y + if logits_processors: + nonlocal tokens + tokens = mx.concat([tokens, y]) if tokens is not None else y - for processor in logits_processor: - logits = processor(tokens, logits) + for processor in logits_processors: + logits = processor(tokens, logits) - y, logprobs = sample(logits) - return y, logprobs.squeeze(0) + maybe_quantize_kv_cache( + prompt_cache, quantized_kv_start, kv_group_size, kv_bits + ) + + logprobs = logits - mx.logsumexp(logits, keepdims=True) + y = sampler(logprobs) + return y, logprobs.squeeze(0) while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=prompt_cache) @@ -247,53 +254,65 @@ def generate_step( y, logprobs = _step(y) mx.async_eval(y, logprobs) + n = 0 while True: next_y, next_logprobs = _step(y) mx.async_eval(next_y, next_logprobs) yield y.item(), logprobs + if n % 256 == 0: + mx.metal.clear_cache() + n += 1 y, logprobs = next_y, next_logprobs def stream_generate( model: nn.Module, tokenizer: Union[PreTrainedTokenizer, TokenizerWrapper], - prompt: str, + prompt: Union[str, List[int]], max_tokens: int = 100, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> Generator[Tuple[str, int, mx.array], None, None]: """ A generator producing text based on the given prompt from the model. Args: - prompt (mx.array): The input prompt. model (nn.Module): The model to use for generation. - max_tokens (int): The ma + tokenizer (PreTrainedTokenizer): The tokenizer. + prompt (Union[str, List[int]]): The input prompt string or integer tokens. 
+ max_tokens (int): The maximum number of tokens. Default: ``100``. kwargs: The remaining options get passed to :func:`generate_step`. See :func:`generate_step` for more details. Yields: - Generator[Tuple[mx.array, mx.array]]: A generator producing text. + Tuple[str, int, mx.array]: + The next text segment, token, and vector of log probabilities. """ if not isinstance(tokenizer, TokenizerWrapper): tokenizer = TokenizerWrapper(tokenizer) - prompt_tokens = mx.array(tokenizer.encode(prompt)) + prompt_tokens = mx.array( + prompt if isinstance(prompt, list) else tokenizer.encode(prompt) + ) detokenizer = tokenizer.detokenizer - detokenizer.reset() - for n, (token, _) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) + with wired_limit(model, [generation_stream]): + detokenizer.reset() + for n, (token, logits) in zip( + range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), + ): + if token == tokenizer.eos_token_id: + break - # Yield the last segment if streaming - yield detokenizer.last_segment + detokenizer.add_token(token) - detokenizer.finalize() - yield detokenizer.last_segment + if n == (max_tokens - 1): + break + + yield detokenizer.last_segment, token, logits + + detokenizer.finalize() + yield detokenizer.last_segment, token, logits def generate( @@ -304,7 +323,7 @@ def generate( verbose: bool = False, formatter: Optional[Callable] = None, **kwargs, -) -> Union[str, Generator[str, None, None]]: +) -> str: """ Generate a complete response from the model. @@ -330,48 +349,49 @@ def generate( prompt_tokens = mx.array(tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer - tic = time.perf_counter() - detokenizer.reset() + with wired_limit(model, [generation_stream]): + tic = time.perf_counter() + detokenizer.reset() + for n, (token, logprobs) in zip( + range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), + ): + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + if token == tokenizer.eos_token_id: + break + detokenizer.add_token(token) - for n, (token, logprobs) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) + if verbose: + if formatter: + # We have to finalize so that the prob corresponds to the last segment + detokenizer.finalize() + prob = mx.exp(logprobs[token]).item() + formatter(detokenizer.last_segment, prob) + else: + print(detokenizer.last_segment, end="", flush=True) + + token_count = n + 1 + detokenizer.finalize() if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - with mx.stream(mx.cpu): - prob = mx.exp(logprobs[token]).item() - formatter(detokenizer.last_segment, prob) - else: - print(detokenizer.last_segment, end="", flush=True) + gen_time = time.perf_counter() - tic + print(detokenizer.last_segment, flush=True) + print("=" * 10) + if token_count == 0: + print("No tokens generated for this prompt") + return + prompt_tps = prompt_tokens.size / prompt_time + gen_tps = (token_count - 1) / gen_time + print( + f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" + ) + print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") + peak_mem = mx.metal.get_peak_memory() / 1e9 + print(f"Peak memory: {peak_mem:.3f} 
GB") - token_count = n + 1 - detokenizer.finalize() - - if verbose: - gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 2**30 - print(f"Peak memory: {peak_mem:.3f} GB") - - return detokenizer.text + return detokenizer.text def load_config(model_path: Path) -> dict: @@ -553,7 +573,9 @@ def upload_to_hub(path: str, upload_repo: str, hf_path: str): f""" # {upload_repo} - The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) using mlx-lm version **{__version__}**. + The Model [{upload_repo}](https://huggingface.co/{upload_repo}) was + converted to MLX format from [{hf_path}](https://huggingface.co/{hf_path}) + using mlx-lm version **{__version__}**. ## Use with mlx diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py index 107be092..6ba81628 100644 --- a/llms/tests/test_finetune.py +++ b/llms/tests/test_finetune.py @@ -3,6 +3,7 @@ import math import sys import unittest +from contextlib import contextmanager from io import StringIO from unittest.mock import MagicMock @@ -17,6 +18,14 @@ from mlx_lm.tuner.trainer import evaluate from mlx_lm.tuner.utils import build_schedule +@contextmanager +def swapped_with_identity(obj, func): + old_func = getattr(obj, func) + setattr(obj, func, lambda x: x) + yield + setattr(obj, func, old_func) + + class TestLora(unittest.TestCase): def setUp(self): self.capturedOutput = StringIO() @@ -374,16 +383,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.4), MagicMock(return_value=180)), (MagicMock(return_value=0.6), MagicMock(return_value=120)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=2, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=2, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset, @@ -412,16 +422,17 @@ class TestScheduleConfig(unittest.TestCase): (MagicMock(return_value=0.2), MagicMock(return_value=150)), ] - evaluate( - model=mock_model, - dataset=mock_dataset, - tokenizer=mock_tokenizer, - batch_size=2, - num_batches=-1, - max_seq_length=2048, - loss=mock_default_loss, - iterate_batches=mock_iterate_batches, - ) + with swapped_with_identity(mx.distributed, "all_sum"): + evaluate( + model=mock_model, + dataset=mock_dataset, + tokenizer=mock_tokenizer, + batch_size=2, + num_batches=-1, + max_seq_length=2048, + loss=mock_default_loss, + iterate_batches=mock_iterate_batches, + ) mock_iterate_batches.assert_called_once_with( dataset=mock_dataset, diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py index 68f1670b..e0a372a9 100644 --- a/llms/tests/test_generate.py +++ b/llms/tests/test_generate.py @@ -46,7 +46,7 @@ class TestGenerate(unittest.TestCase): "hello", max_tokens=5, verbose=False, - 
logits_processor=[logits_processor], + logits_processors=[logits_processor], ) self.assertEqual(len(all_toks), len(init_toks) + 5) diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py index 64cd9486..0867ab56 100644 --- a/llms/tests/test_prompt_cache.py +++ b/llms/tests/test_prompt_cache.py @@ -9,6 +9,7 @@ import mlx.core as mx from mlx_lm.models.cache import ( KVCache, MambaCache, + QuantizedKVCache, RotatingKVCache, load_prompt_cache, make_prompt_cache, @@ -186,6 +187,18 @@ class TestPromptCache(unittest.TestCase): num_trimmed = trim_prompt_cache(cache, 4) self.assertEqual(num_trimmed, 0) + cache = [QuantizedKVCache() for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 64)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 7) + + # Trim more tokens than remain + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 3) + def test_trim_cache_with_generate(self): model, tokenizer = load(HF_MODEL_PATH) prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] @@ -238,6 +251,56 @@ class TestPromptCache(unittest.TestCase): self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y)) self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z)) + def test_save_load_quantized_cache(self): + cache = [QuantizedKVCache(bits=4, group_size=32) for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 32)) + c.update_and_fetch(x, x) + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(loaded_cache[0].bits == cache[0].bits) + self.assertTrue(loaded_cache[0].group_size == cache[0].group_size) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + # Loop over quantized tuple + for i in range(3): + self.assertTrue(mx.array_equal(c.state[0][i], lc.state[0][i])) + self.assertTrue(mx.array_equal(c.state[1][i], lc.state[1][i])) + + # Test with metadata + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + metadata = {"a": "b", "c": "d"} + save_prompt_cache(cache_file, cache, metadata) + _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) + self.assertEqual(metadata, loaded_metadata) + + def test_cache_to_quantized(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + results = zip(range(4), generate_step(prompt, model)) + toks, all_logits = zip(*(r[1] for r in results)) + + prompt_cache = make_prompt_cache(model) + i = 0 + for _, (tok, logits) in zip( + range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + ): + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + i += 1 + + prompt_cache = [c.to_quantized(bits=8, group_size=32) for c in prompt_cache] + + for _, (tok, logits) in zip( + range(1), + generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + ): + i += 1 + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2)) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_tokenizers.py b/llms/tests/test_tokenizers.py index 03445c1f..9c30d51e 100644 --- a/llms/tests/test_tokenizers.py +++ b/llms/tests/test_tokenizers.py @@ -42,6 +42,9 @@ class TestTokenizers(unittest.TestCase): text += detokenizer.last_segment 
self.assertEqual(text, expected_text) + tokens = tokenizer.encode("こんにちは!私の名前はAI") + check(tokens) + tokens = tokenizer.encode("a ,b") check(tokens) @@ -74,6 +77,17 @@ class TestTokenizers(unittest.TestCase): tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer) self.check_tokenizer(tokenizer) + def test_special_tokens(self): + tokenizer_repo = "mlx-community/DeepSeek-Coder-V2-Lite-Instruct-4bit-mlx" + tokenizer = self.download_tokenizer(tokenizer_repo) + + detokenizer = tokenizer.detokenizer + detokenizer.reset() + detokenizer.add_token(tokenizer.eos_token_id) + detokenizer.finalize() + + self.assertEqual(detokenizer.last_segment, tokenizer.eos_token) + if __name__ == "__main__": unittest.main() diff --git a/whisper/README.md b/whisper/README.md index ac6e95f6..cd3bc684 100644 --- a/whisper/README.md +++ b/whisper/README.md @@ -25,7 +25,7 @@ pip install mlx-whisper At its simplest: -``` +```sh mlx_whisper audio_file.mp3 ``` @@ -35,6 +35,15 @@ Use `-f` to specify the output format and `--model` to specify the model. There are many other supported command line options. To see them all, run `mlx_whisper -h`. +You can also pipe the audio content of other programs via stdin: + +```sh +some-process | mlx_whisper - +``` + +The default output file name will be `content.*`. You can specify the name with +the `--output-name` flag. + #### API Transcribe audio with: @@ -103,7 +112,7 @@ python convert.py --help ``` By default, the conversion script will make the directory `mlx_models` -and save the converted `weights.npz` and `config.json` there. +and save the converted `weights.npz` and `config.json` there. Each time it is run, `convert.py` will overwrite any model in the provided path. To save different models, make sure to set `--mlx-path` to a unique diff --git a/whisper/convert.py b/whisper/convert.py index cdd50bc5..301fd5b4 100644 --- a/whisper/convert.py +++ b/whisper/convert.py @@ -181,7 +181,7 @@ def load_torch_weights_and_config( ) if name_or_path.endswith(".pt"): - checkpoint = torch.load(name_or_path, map_location="cpu") + checkpoint = torch.load(name_or_path, map_location="cpu", weights_only=False) weights, config = checkpoint["model_state_dict"], checkpoint["dims"] else: name_or_path = Path(name_or_path) @@ -387,7 +387,7 @@ if __name__ == "__main__": # Save weights print("[INFO] Saving") - np.savez(str(mlx_path / "weights.npz"), **weights) + mx.save_safetensors(str(mlx_path / "weights.safetensors"), weights) # Save config.json with model_type with open(str(mlx_path / "config.json"), "w") as f: diff --git a/whisper/mlx_whisper/_version.py b/whisper/mlx_whisper/_version.py index 67c7397c..8280e038 100644 --- a/whisper/mlx_whisper/_version.py +++ b/whisper/mlx_whisper/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.3.0" +__version__ = "0.4.1" diff --git a/whisper/mlx_whisper/audio.py b/whisper/mlx_whisper/audio.py index e04309c1..c8cca07c 100644 --- a/whisper/mlx_whisper/audio.py +++ b/whisper/mlx_whisper/audio.py @@ -3,7 +3,7 @@ import os from functools import lru_cache from subprocess import CalledProcessError, run -from typing import Union +from typing import Optional, Union import mlx.core as mx import numpy as np @@ -21,7 +21,7 @@ FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH # 10ms per audio frame TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN # 20ms per audio token -def load_audio(file: str, sr: int = SAMPLE_RATE): +def load_audio(file: str = Optional[str], sr: int = SAMPLE_RATE, from_stdin=False): """ Open an audio file and read as mono waveform, resampling as necessary @@ -39,19 +39,21 @@ def load_audio(file: str, sr: int = SAMPLE_RATE): """ # This launches a subprocess to decode audio while down-mixing - # and resampling as necessary. Requires the ffmpeg CLI in PATH. + # and resampling as necessary. Requires the ffmpeg CLI in PATH. + if from_stdin: + cmd = ["ffmpeg", "-i", "pipe:0"] + else: + cmd = ["ffmpeg", "-nostdin", "-i", file] + # fmt: off - cmd = [ - "ffmpeg", - "-nostdin", + cmd.extend([ "-threads", "0", - "-i", file, "-f", "s16le", "-ac", "1", "-acodec", "pcm_s16le", "-ar", str(sr), "-" - ] + ]) # fmt: on try: out = run(cmd, capture_output=True, check=True).stdout diff --git a/whisper/mlx_whisper/cli.py b/whisper/mlx_whisper/cli.py index c2813338..7d08a043 100644 --- a/whisper/mlx_whisper/cli.py +++ b/whisper/mlx_whisper/cli.py @@ -2,9 +2,11 @@ import argparse import os +import pathlib import traceback import warnings +from . import audio from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE from .transcribe import transcribe from .writers import get_writer @@ -27,15 +29,24 @@ def build_parser(): parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter ) - parser.add_argument( - "audio", nargs="+", type=str, help="Audio file(s) to transcribe" - ) + + parser.add_argument("audio", nargs="+", help="Audio file(s) to transcribe") + parser.add_argument( "--model", default="mlx-community/whisper-tiny", type=str, help="The model directory or hugging face repo", ) + parser.add_argument( + "--output-name", + type=str, + default=None, + help=( + "The name of transcription/translation output files before " + "--output-format extensions" + ), + ) parser.add_argument( "--output-dir", "-o", @@ -200,6 +211,7 @@ def main(): path_or_hf_repo: str = args.pop("model") output_dir: str = args.pop("output_dir") output_format: str = args.pop("output_format") + output_name: str = args.pop("output_name") os.makedirs(output_dir, exist_ok=True) writer = get_writer(output_format, output_dir) @@ -219,17 +231,25 @@ def main(): warnings.warn("--max-line-count has no effect without --max-line-width") if writer_args["max_words_per_line"] and writer_args["max_line_width"]: warnings.warn("--max-words-per-line has no effect with --max-line-width") - for audio_path in args.pop("audio"): + + for audio_obj in args.pop("audio"): + if audio_obj == "-": + # receive the contents from stdin rather than read a file + audio_obj = audio.load_audio(from_stdin=True) + + output_name = output_name or "content" + else: + output_name = output_name or pathlib.Path(audio_obj).stem try: result = transcribe( - audio_path, + audio_obj, path_or_hf_repo=path_or_hf_repo, **args, ) - writer(result, audio_path, **writer_args) + writer(result, output_name, **writer_args) except Exception as 
e: traceback.print_exc() - print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}") + print(f"Skipping {audio_obj} due to {type(e).__name__}: {str(e)}") if __name__ == "__main__": diff --git a/whisper/mlx_whisper/decoding.py b/whisper/mlx_whisper/decoding.py index 41c2ec6d..4e060cd5 100644 --- a/whisper/mlx_whisper/decoding.py +++ b/whisper/mlx_whisper/decoding.py @@ -58,11 +58,12 @@ def detect_language( logits = model.logits(x, mel)[:, 0] # collect detected languages; suppress all non-language tokens - mask = np.full(logits.shape[-1], -np.inf, dtype=np.float32) + mask = mx.full(logits.shape[-1], -mx.inf, dtype=mx.float32) mask[list(tokenizer.all_language_tokens)] = 0.0 - logits += mx.array(mask) + logits += mask language_tokens = mx.argmax(logits, axis=-1) language_token_probs = mx.softmax(logits, axis=-1) + language_token_probs = np.array(language_token_probs) language_probs = [ { c: language_token_probs[i, j].item() @@ -129,17 +130,12 @@ class DecodingResult: class Inference: - def __init__(self, model: "Whisper", initial_token_length: int): + def __init__(self, model: "Whisper"): self.model: "Whisper" = model - self.initial_token_length = initial_token_length self.kv_cache = None def logits(self, tokens: mx.array, audio_features: mx.array) -> mx.array: """Perform a forward pass on the decoder and return per-token logits""" - if tokens.shape[-1] > self.initial_token_length: - # only need to use the last token except in the first forward pass - tokens = tokens[:, -1:] - logits, self.kv_cache, _ = self.model.decoder( tokens, audio_features, kv_cache=self.kv_cache ) @@ -251,6 +247,11 @@ class TokenDecoder: raise NotImplementedError +@mx.compile +def categorical(logits, temp): + return mx.random.categorical(logits / temp) + + class GreedyDecoder(TokenDecoder): def __init__(self, temperature: float, eot: int): self.temperature = temperature @@ -262,10 +263,8 @@ class GreedyDecoder(TokenDecoder): if self.temperature == 0: next_tokens = logits.argmax(axis=-1) else: - next_tokens = mx.random.categorical(logits=logits / self.temperature) + next_tokens = categorical(logits, self.temperature) - next_tokens = mx.argmax(logits, axis=-1) - logits = logits.astype(mx.float32) logprobs = logits - mx.logsumexp(logits, axis=-1) current_logprobs = logprobs[mx.arange(logprobs.shape[0]), next_tokens] @@ -281,7 +280,7 @@ class GreedyDecoder(TokenDecoder): def finalize(self, tokens: mx.array, sum_logprobs: mx.array): # make sure each sequence has at least one EOT token at the end tokens = mx.pad(tokens, [(0, 0), (0, 0), (0, 1)], constant_values=self.eot) - return tokens, sum_logprobs.tolist() + return tokens, sum_logprobs class LogitFilter: @@ -340,10 +339,10 @@ class ApplyTimestampRules(LogitFilter): if self.tokenizer.no_timestamps is not None: mask[:, self.tokenizer.no_timestamps] = -np.inf - # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly - for k in range(tokens.shape[0]): - sampled_tokens = tokens[k, self.sample_begin :] - seq = sampled_tokens.tolist() + ## timestamps have to appear in pairs, except directly before EOT; mask logits accordingly + tokens = tokens.tolist() + for k in range(len(tokens)): + seq = tokens[k][self.sample_begin :] last_was_timestamp = ( len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin ) @@ -368,7 +367,7 @@ class ApplyTimestampRules(LogitFilter): last_timestamp += 1 mask[k, self.tokenizer.timestamp_begin : last_timestamp] = -np.inf - if tokens.shape[1] == self.sample_begin: + if len(tokens[0]) == self.sample_begin: # 
suppress generating non-timestamp tokens at the beginning mask[:, : self.tokenizer.timestamp_begin] = -np.inf @@ -380,16 +379,20 @@ class ApplyTimestampRules(LogitFilter): mask[:, last_allowed + 1 :] = -np.inf # if sum of probability over timestamps is above any other token, sample timestamp + mask = mx.array(mask) logprobs = logits - mx.logsumexp(logits, axis=-1) - for k in range(tokens.shape[0]): - timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp( - axis=-1 - ) - max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max() - if timestamp_logprob > max_text_token_logprob: - mask[k, : self.tokenizer.timestamp_begin] = -np.inf - - return logits + mx.array(mask, logits.dtype) + timestamp_logprob = logprobs[:, self.tokenizer.timestamp_begin :].logsumexp( + axis=-1, keepdims=True + ) + max_text_token_logprob = logprobs[:, : self.tokenizer.timestamp_begin].max( + axis=-1, keepdims=True + ) + mask[:, : self.tokenizer.timestamp_begin] = mx.where( + timestamp_logprob > max_text_token_logprob, + -mx.inf, + mask[:, : self.tokenizer.timestamp_begin], + ) + return logits + mask class DecodingTask: @@ -424,7 +427,7 @@ class DecodingTask: self.sot_index: int = self.initial_tokens.index(tokenizer.sot) # inference: implements the forward pass through the decoder, including kv caching - self.inference = Inference(model, len(self.initial_tokens)) + self.inference = Inference(model) # sequence ranker: implements how to rank a group of sampled sequences self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty) @@ -432,9 +435,6 @@ class DecodingTask: # decoder: implements how to select the next tokens, given the autoregressive distribution if options.beam_size is not None: raise NotImplementedError("Beam search decoder is not yet implemented") - # self.decoder = BeamSearchDecoder( - # options.beam_size, tokenizer.eot, self.inference, options.patience - # ) else: self.decoder = GreedyDecoder(options.temperature, tokenizer.eot) @@ -448,6 +448,7 @@ class DecodingTask: self.logit_filters.append( SuppressTokens(self._get_suppress_tokens(), model.dims.n_vocab) ) + if not options.without_timestamps: precision = CHUNK_LENGTH / model.dims.n_audio_ctx # usually 0.02 seconds max_initial_timestamp_index = None @@ -570,48 +571,59 @@ class DecodingTask: def _main_loop(self, audio_features: mx.array, tokens: mx.array): n_batch = tokens.shape[0] - sum_logprobs: mx.array = mx.zeros(n_batch) - no_speech_probs = [np.nan] * n_batch + sum_logprobs = mx.zeros(n_batch) - try: - for i in range(self.sample_len): - logits = self.inference.logits(tokens, audio_features) + def _step(inputs, audio_features, tokens, sum_logprobs): + pre_logits = self.inference.logits(inputs, audio_features) - if ( - i == 0 and self.tokenizer.no_speech is not None - ): # save no_speech_probs - probs_at_sot = mx.softmax( - logits[:, self.sot_index].astype(mx.float32), axis=-1 - ) - no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist() + # consider the logits at the last token only + logits = pre_logits[:, -1] - # now we need to consider the logits at the last token only - logits = logits[:, -1] + # apply the logit filters, e.g. for suppressing or applying penalty to + for logit_filter in self.logit_filters: + logits = logit_filter.apply(logits, tokens) - # apply the logit filters, e.g. 
for suppressing or applying penalty to - for logit_filter in self.logit_filters: - logits = logit_filter.apply(logits, tokens) + # expand the tokens tensor with the selected next tokens + tokens, completed, sum_logprobs = self.decoder.update( + tokens, logits, sum_logprobs + ) + return tokens, completed, sum_logprobs, pre_logits - # expand the tokens tensor with the selected next tokens - tokens, completed, sum_logprobs = self.decoder.update( - tokens, logits, sum_logprobs - ) + tokens, completed, sum_logprobs, pre_logits = _step( + tokens, audio_features, tokens, sum_logprobs + ) + if self.tokenizer.no_speech is not None: # compute no_speech_probs + probs_at_sot = mx.softmax(pre_logits[:, self.sot_index], axis=-1) + no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech] + else: + no_speech_probs = mx.full(n_batch, mx.nan) + mx.async_eval(completed, tokens, sum_logprobs, no_speech_probs) - if completed or tokens.shape[-1] > self.n_ctx: - break - finally: - self.inference.reset() + for i in range(1, self.sample_len): + inputs = tokens[:, -1:] + if tokens.shape[-1] > self.n_ctx: + break + next_tokens, next_completed, next_sum_logprobs, _ = _step( + inputs, audio_features, tokens, sum_logprobs + ) + mx.async_eval(next_completed, next_tokens, next_sum_logprobs) + if completed: + break + tokens = next_tokens + completed = next_completed + sum_logprobs = next_sum_logprobs return tokens, sum_logprobs, no_speech_probs def run(self, mel: mx.array) -> List[DecodingResult]: + self.inference.reset() self.decoder.reset() tokenizer: Tokenizer = self.tokenizer n_audio: int = mel.shape[0] audio_features: mx.array = self._get_audio_features(mel) # encoder forward pass - tokens: np.array = np.array(self.initial_tokens) - tokens = np.broadcast_to(tokens, (n_audio, len(self.initial_tokens))).copy() + tokens: mx.array = mx.array(self.initial_tokens) + tokens = mx.broadcast_to(tokens, (n_audio, len(self.initial_tokens))) # detect language if requested, overwriting the language token languages, language_probs = self._detect_language(audio_features, tokens) @@ -626,7 +638,6 @@ class DecodingTask: ] # repeat tokens by the group size, for beam search or best-of-n sampling - tokens = mx.array(tokens) if self.n_group > 1: tokens = tokens[:, None, :] tokens = mx.broadcast_to( @@ -649,7 +660,13 @@ class DecodingTask: # get the final candidates for each group, and slice between the first sampled token and EOT tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs) - tokens = tokens[..., self.sample_begin :].tolist() + tokens = tokens[..., self.sample_begin :] + + # eval and convert to list + mx.eval(tokens, sum_logprobs, no_speech_probs) + tokens = tokens.tolist() + sum_logprobs = sum_logprobs.tolist() + no_speech_probs = no_speech_probs.tolist() tokens = [[t[: t.index(tokenizer.eot)] for t in s] for s in tokens] # select the top-ranked sample in each group diff --git a/whisper/mlx_whisper/load_models.py b/whisper/mlx_whisper/load_models.py index 6705385d..60766ab2 100644 --- a/whisper/mlx_whisper/load_models.py +++ b/whisper/mlx_whisper/load_models.py @@ -26,7 +26,10 @@ def load_model( model_args = whisper.ModelDimensions(**config) - weights = mx.load(str(model_path / "weights.npz")) + wf = model_path / "weights.safetensors" + if not wf.exists(): + wf = model_path / "weights.npz" + weights = mx.load(str(wf)) model = whisper.Whisper(model_args, dtype) diff --git a/whisper/mlx_whisper/transcribe.py b/whisper/mlx_whisper/transcribe.py index 786b4232..7057679b 100644 --- a/whisper/mlx_whisper/transcribe.py 
+++ b/whisper/mlx_whisper/transcribe.py @@ -293,6 +293,7 @@ def transcribe( decode_options["prompt"] = all_tokens[prompt_reset_since:] result: DecodingResult = decode_with_fallback(mel_segment) + tokens = np.array(result.tokens) if no_speech_threshold is not None: diff --git a/whisper/mlx_whisper/whisper.py b/whisper/mlx_whisper/whisper.py index e691792c..1c2b390e 100644 --- a/whisper/mlx_whisper/whisper.py +++ b/whisper/mlx_whisper/whisper.py @@ -80,12 +80,11 @@ class MultiHeadAttention(nn.Module): qk = q @ k if mask is not None: qk = qk + mask[:n_ctx, :n_ctx] - qk = qk.astype(mx.float32) - w = mx.softmax(qk, axis=-1).astype(q.dtype) + w = mx.softmax(qk, axis=-1, precise=True) out = (w @ v).transpose(0, 2, 1, 3) out = out.reshape(n_batch, n_ctx, n_state) - return out, qk + return out, qk.astype(mx.float32) class ResidualAttentionBlock(nn.Module): diff --git a/whisper/mlx_whisper/writers.py b/whisper/mlx_whisper/writers.py index 464ead18..cdb35063 100644 --- a/whisper/mlx_whisper/writers.py +++ b/whisper/mlx_whisper/writers.py @@ -1,10 +1,8 @@ # Copyright © 2024 Apple Inc. import json -import os +import pathlib import re -import sys -import zlib from typing import Callable, List, Optional, TextIO @@ -43,15 +41,13 @@ class ResultWriter: self.output_dir = output_dir def __call__( - self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs + self, result: dict, output_name: str, options: Optional[dict] = None, **kwargs ): - audio_basename = os.path.basename(audio_path) - audio_basename = os.path.splitext(audio_basename)[0] - output_path = os.path.join( - self.output_dir, audio_basename + "." + self.extension + output_path = (pathlib.Path(self.output_dir) / output_name).with_suffix( + f".{self.extension}" ) - with open(output_path, "w", encoding="utf-8") as f: + with output_path.open("wt", encoding="utf-8") as f: self.write_result(result, file=f, options=options, **kwargs) def write_result(