diff --git a/.gitignore b/.gitignore
index f3dfe929..45445fc8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,9 @@ __pycache__/
 # C extensions
 *.so
 
+# Vim
+*.swp
+
 # Distribution / packaging
 .Python
 build/
diff --git a/llms/README.md b/llms/README.md
index 75677865..20863041 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -20,6 +20,31 @@ The `mlx-lm` package also has:
 - [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
 - [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
 
+### Quick Start
+
+To generate text with an LLM use:
+
+```bash
+mlx_lm.generate --prompt "Hi!"
+```
+
+To chat with an LLM use:
+
+```bash
+mlx_lm.chat
+```
+
+This will give you a chat REPL that you can use to interact with the LLM. The
+chat context is preserved during the lifetime of the REPL.
+
+Commands in `mlx-lm` typically take command line options which let you specify
+the model, sampling parameters, and more. Use `-h` to see a list of available
+options for a command, e.g.:
+
+```bash
+mlx_lm.generate -h
+```
+
 ### Python API
 
 You can use `mlx-lm` as a module:
@@ -138,7 +163,7 @@ mlx_lm.convert \
 
 ### Long Prompts and Generations
 
-MLX LM has some tools to scale efficiently to long prompts and generations:
+`mlx-lm` has some tools to scale efficiently to long prompts and generations:
 
 - A rotating fixed-size key-value cache.
 - Prompt caching
@@ -155,14 +180,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
 cat prompt.txt | mlx_lm.cache_prompt \
   --model mistralai/Mistral-7B-Instruct-v0.3 \
   --prompt - \
-  --kv-cache-file mistral_prompt.safetensors
+  --prompt-cache-file mistral_prompt.safetensors
 ```
 
 Then use the cached prompt with `mlx_lm.generate`:
 
 ```
 mlx_lm.generate \
-  --kv-cache-file mistral_prompt.safetensors \
+  --prompt-cache-file mistral_prompt.safetensors \
   --prompt "\nSummarize the above text."
 ```
 
@@ -170,9 +195,15 @@ The cached prompt is treated as a prefix to the supplied prompt. Also notice
 when using a cached prompt, the model to use is read from the cache and need
 not be supplied explicitly.
 
+Prompt caching can also be used in the Python API in order to avoid
+recomputing the prompt. This is useful in multi-turn dialogues or across
+requests that use the same context. See the
+[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
+for more usage details.
+
 ### Supported Models
 
-MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
+`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
 run is not supported, file an
 [issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
 submit a pull request.
diff --git a/llms/a.py b/llms/a.py
deleted file mode 100644
index 94a72239..00000000
--- a/llms/a.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import mlx_lm
-
-# model, tokenizer = mlx_lm.load("mlx-community/SmolLM-1.7B-Instruct-fp16")
-model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Qwen2-0.5B-8bit-Instruct")
-draft_model, draft_tokenizer = mlx_lm.load("mlx-community/SmolLM-135M-Instruct-4bit")
-
-# https://github.com/hemingkx/Spec-Bench/blob/main/data/spec_bench/question.jsonl
-prompt = "Develop a Python program that reads all the text files under a directory and returns top-5 words with the most number of occurrences."
- -prompt = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - tokenize=False, - add_generation_prompt=True, -) - -mlx_lm.generate( - model, - tokenizer, - prompt=prompt, - verbose=True, - max_tokens=500, - temp=1.0, - min_p=0.1, - repetition_penalty=1.2, - # draft_model=draft_model, -) diff --git a/llms/b.py b/llms/b.py deleted file mode 100644 index 54573ba4..00000000 --- a/llms/b.py +++ /dev/null @@ -1,41 +0,0 @@ -import mlx_lm -import random -import string - -model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Qwen2-0.5B-8bit-Instruct") - -capital_letters = string.ascii_uppercase -distinct_pairs = [ - (a, b) for i, a in enumerate(capital_letters) for b in capital_letters[i + 1 :] -] - -num_prompts = 16 -prompt_template = "Think of a real word containing both the letters {l1} and {l2}. Then, say 3 sentences which use the word." -prompts = [ - prompt_template.format(l1=p[0], l2=p[1]) - for p in random.sample(distinct_pairs, num_prompts) -] -prompts = [ - "Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?", - "James writes a 3-page letter to 2 different friends twice a week. How many pages does he write a year?", - "Tina makes $18.00 an hour. If she works more than 8 hours per shift, she is eligible for overtime, which is paid by your hourly wage + 1/2 your hourly wage. If she works 10 hours every day for 5 days, how much money does she make?" -] -prompts = [ - tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - tokenize=False, - add_generation_prompt=True, - ) - for prompt in prompts -] - -response = mlx_lm.batch_generate( - model, - tokenizer, - prompts=prompts, - max_tokens=512, - verbose=True, - temp=1.0, - min_p=0.1, - repetition_penalty=1.2, -) diff --git a/llms/c.py b/llms/c.py deleted file mode 100644 index 6f4034ab..00000000 --- a/llms/c.py +++ /dev/null @@ -1,11 +0,0 @@ -import mlx_lm - -model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit") - -for s in mlx_lm.stream_generate( - model, - tokenizer, - prompt="Meta Llama 3.1 is a ", - max_tokens=100, -): - print(s, end="", flush=True) diff --git a/llms/d.py b/llms/d.py deleted file mode 100644 index 34ebf620..00000000 --- a/llms/d.py +++ /dev/null @@ -1,11 +0,0 @@ -import mlx_lm - -model, tokenizer = mlx_lm.load("/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit") - -for s in mlx_lm.stream_generate( - model, - tokenizer, - prompt=["Meta Llama 3.1 is a ", "Google Gemma 2 is a "], - max_tokens=20, -): - print(s[0].ljust(30) + s[1], flush=True) diff --git a/llms/issue.txt b/llms/issue.txt deleted file mode 100644 index 245b4d92..00000000 --- a/llms/issue.txt +++ /dev/null @@ -1,21 +0,0 @@ -## Steps to reproduce - -Run the following with and without `prefill_step_size=2` commented out: - -```py -import mlx_lm - -model, tokenizer = mlx_lm.load('/Users/llwu/models/mlx/Meta-Llama-3.1-8B-4bit') - -mlx_lm.generate( - model, - tokenizer, - prompt="69 + 420= ", - verbose=True, - max_tokens=10, - max_kv_size=5, - prefill_step_size=2, -) -``` - -The output is different. I notice that the RotatingKVCache has length 5 with prefill and length 7 without. diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py index 8110c823..70239db6 100644 --- a/llms/mlx_lm/_version.py +++ b/llms/mlx_lm/_version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.18.2" +__version__ = "0.19.1" diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 9829efb4..04e75a3e 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -7,13 +7,14 @@ import time import mlx.core as mx -from .utils import load, make_kv_caches +from .models.cache import make_prompt_cache, save_prompt_cache +from .utils import load def setup_arg_parser(): """Set up and return the argument parser.""" parser = argparse.ArgumentParser( - description="Cache the KV cache of a prompt to be reused with mlx_lm.generate" + description="Cache the state of a prompt to be reused with mlx_lm.generate" ) parser.add_argument( "--model", @@ -60,7 +61,9 @@ def setup_arg_parser(): help="Set the maximum key-value cache size", ) parser.add_argument( - "--kv-cache-file", help="The file to save the KV caches in", required=True + "--prompt-cache-file", + help="The file to save the prompt cache in", + required=True, ) parser.add_argument( "--prompt", @@ -115,7 +118,7 @@ def main(): else: prompt = args.prompt - cache = make_kv_caches(model, args.max_kv_size) + cache = make_prompt_cache(model, args.max_kv_size) y = mx.array(tokenizer.encode(prompt)) # Process the prompt @@ -137,16 +140,12 @@ def main(): print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") print("Saving...") - cache_dict = {} - for i, c in enumerate(cache): - cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :] - cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :] metadata = {} metadata["model"] = args.model metadata["chat_template"] = tokenizer.chat_template metadata["tokenizer_config"] = json.dumps(tokenizer_config) - metadata["max_kv_size"] = str(args.max_kv_size) - mx.save_safetensors(args.kv_cache_file, cache_dict, metadata) + print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") + save_prompt_cache(args.prompt_cache_file, cache, metadata) if __name__ == "__main__": diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py new file mode 100644 index 00000000..7968a868 --- /dev/null +++ b/llms/mlx_lm/chat.py @@ -0,0 +1,82 @@ +# Copyright © 2023-2024 Apple Inc. 
+
+import argparse
+import json
+
+import mlx.core as mx
+
+from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
+from .utils import load, stream_generate
+
+DEFAULT_TEMP = 0.0
+DEFAULT_TOP_P = 1.0
+DEFAULT_SEED = 0
+DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
+
+
+def setup_arg_parser():
+    """Set up and return the argument parser."""
+    parser = argparse.ArgumentParser(description="Chat with an LLM")
+    parser.add_argument(
+        "--model",
+        type=str,
+        help="The path to the local model directory or Hugging Face repo.",
+        default=DEFAULT_MODEL,
+    )
+    parser.add_argument(
+        "--adapter-path",
+        type=str,
+        help="Optional path for the trained adapter weights and config.",
+    )
+    parser.add_argument(
+        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
+    )
+    parser.add_argument(
+        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
+    )
+    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
+    parser.add_argument(
+        "--max-kv-size",
+        type=int,
+        help="Set the maximum key-value cache size",
+        default=None,
+    )
+    return parser
+
+
+def main():
+    parser = setup_arg_parser()
+    args = parser.parse_args()
+
+    mx.random.seed(args.seed)
+
+    model, tokenizer = load(
+        args.model,
+        adapter_path=args.adapter_path,
+        tokenizer_config={"trust_remote_code": True},
+    )
+
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
+    prompt_cache = make_prompt_cache(model, args.max_kv_size)
+    while True:
+        query = input(">> ")
+        if query == "q":
+            break
+        messages = [{"role": "user", "content": query}]
+        prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        for response in stream_generate(
+            model,
+            tokenizer,
+            prompt,
+            temp=args.temp,
+            top_p=args.top_p,
+            prompt_cache=prompt_cache,
+        ):
+            print(response, flush=True, end="")
+        print()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py
new file mode 100644
index 00000000..3bf01688
--- /dev/null
+++ b/llms/mlx_lm/examples/chat.py
@@ -0,0 +1,53 @@
+# Copyright © 2024 Apple Inc.
+
+"""
+An example of a multi-turn chat with prompt caching.
+"""
+
+from mlx_lm import generate, load
+from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
+
+model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
+
+# Make the initial prompt cache for the model
+prompt_cache = make_prompt_cache(model)
+
+# User turn
+prompt = "Hi my name is <Name>."
+messages = [{"role": "user", "content": prompt}]
+prompt = tokenizer.apply_chat_template(
+    messages, tokenize=False, add_generation_prompt=True
+)
+
+# Assistant response
+response = generate(
+    model,
+    tokenizer,
+    prompt=prompt,
+    verbose=True,
+    temp=0.0,
+    prompt_cache=prompt_cache,
+)
+
+# User turn
+prompt = "What's my name?"
+messages = [{"role": "user", "content": prompt}] +prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) + +# Assistant response +response = generate( + model, + tokenizer, + prompt=prompt, + verbose=True, + temp=0.0, + prompt_cache=prompt_cache, +) + +# Save the prompt cache to disk to reuse it at a later time +save_prompt_cache("mistral_prompt.safetensors", prompt_cache) + +# Load the prompt cache from disk +prompt_cache = load_prompt_cache("mistral_prompt.safetensors") diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index af599c1b..25730617 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. + from mlx_lm import generate, load # Specify the checkpoint diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 537bd853..0bf98ab2 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -6,13 +6,15 @@ import sys import mlx.core as mx +from .models.cache import load_prompt_cache from .utils import generate, load DEFAULT_PROMPT = "hello" DEFAULT_MAX_TOKENS = 100 -DEFAULT_TEMP = 0.6 +DEFAULT_TEMP = 0.0 DEFAULT_TOP_P = 1.0 DEFAULT_SEED = 0 +DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit" def str2bool(string): @@ -25,7 +27,11 @@ def setup_arg_parser(): parser.add_argument( "--model", type=str, - help="The path to the local model directory or Hugging Face repo.", + help=( + "The path to the local model directory or Hugging Face repo. " + f"If no model is specified, then {DEFAULT_MODEL} is used." + ), + default=None, ) parser.add_argument( "--adapter-path", @@ -96,7 +102,7 @@ def setup_arg_parser(): default=None, ) parser.add_argument( - "--kv-cache-file", + "--prompt-cache-file", type=str, default=None, help="A file containing saved KV caches to avoid recomputing them", @@ -131,24 +137,6 @@ def colorprint_by_t0(s, t0): colorprint(color, s) -def load_kv_cache_from_file(kv_cache_file): - if kv_cache_file is None: - return None, None - - kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True) - cache_per_layer = {} - for k, x in kv_cache.items(): - layer, kv_type = k.split("_") - if layer not in cache_per_layer: - cache_per_layer[layer] = {} - cache_per_layer[layer][kv_type] = x - - cache_history = [None] * len(cache_per_layer) - for layer, c in cache_per_layer.items(): - cache_history[int(layer)] = (c["keys"], c["values"]) - return cache_history, metadata - - def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -158,22 +146,33 @@ def main(): if args.cache_limit_gb is not None: mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Load the kv cache and metadata if a kv cache file is provided - cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file) + # Load the prompt cache and metadata if a cache file is provided + using_cache = args.prompt_cache_file is not None + if using_cache: + prompt_cache, metadata = load_prompt_cache( + args.prompt_cache_file, return_metadata=True + ) # Building tokenizer_config tokenizer_config = ( - {} if cache_history is None else json.loads(metadata["tokenizer_config"]) + {} if not using_cache else json.loads(metadata["tokenizer_config"]) ) if args.trust_remote_code: tokenizer_config["trust_remote_code"] = True if args.eos_token is not None: tokenizer_config["eos_token"] = args.eos_token - # If no model path is provided then use the one in the kv cache history model_path = args.model - if 
cache_history is not None and model_path is None: - model_path = metadata["model"] + if using_cache: + if model_path is None: + model_path = metadata["model"] + elif model_path != metadata["model"]: + raise ValueError( + f"Providing a different model ({model_path}) than that " + f"used to create the prompt cache ({metadata['model']}) " + "is an error." + ) + model_path = model_path or DEFAULT_MODEL model, tokenizer = load( model_path, @@ -184,7 +183,7 @@ def main(): if args.use_default_chat_template: if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template - elif cache_history is not None: + elif using_cache: tokenizer.chat_template = metadata["chat_template"] if not args.ignore_chat_template and ( @@ -203,7 +202,7 @@ def main(): # Treat the prompt as a suffix assuming that the prefix is in the # stored kv cache. - if cache_history is not None: + if using_cache: test_prompt = tokenizer.apply_chat_template( [{"role": "user", "content": ""}], tokenize=False, @@ -217,12 +216,6 @@ def main(): raise ValueError("Cannot use --colorize with --verbose=False") formatter = colorprint_by_t0 if args.colorize else None - # Determine the max kv size from the kv cache or passed arguments - max_kv_size = args.max_kv_size - if cache_history is not None: - max_kv_size = metadata["max_kv_size"] - max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None - response = generate( model, tokenizer, @@ -232,8 +225,8 @@ def main(): formatter=formatter, temp=args.temp, top_p=args.top_p, - max_kv_size=max_kv_size, - cache_history=cache_history, + max_kv_size=args.max_kv_size, + prompt_cache=prompt_cache if using_cache else None, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index dc19dd05..3628a808 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -2,145 +2,9 @@ import inspect from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Optional import mlx.core as mx -import mlx.nn as nn - - -class KVCache: - - def __init__(self, head_dim, n_kv_heads): - self.n_kv_heads = n_kv_heads - if isinstance(head_dim, int): - self.k_head_dim = self.v_head_dim = head_dim - elif isinstance(head_dim, tuple) and len(head_dim) == 2: - self.k_head_dim, self.v_head_dim = head_dim - else: - raise ValueError("head_dim must be an int or a tuple of two ints") - self.keys = None - self.values = None - self.offset = 0 - self.step = 256 - - def update_and_fetch(self, keys, values): - prev = self.offset - if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: - B = keys.shape[0] - n_steps = (self.step + keys.shape[2] - 1) // self.step - k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim) - v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - if prev % self.step != 0: - self.keys = self.keys[..., :prev, :] - self.values = self.values[..., :prev, :] - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - - self.offset += keys.shape[2] - self.keys[..., prev : self.offset, :] = keys - self.values[..., prev : self.offset, :] = values - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - - @property - def state(self): - return self.keys, self.values - - -class RotatingKVCache: - - def 
__init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256): - self.n_kv_heads = n_kv_heads - if isinstance(head_dim, int): - self.k_head_dim = self.v_head_dim = head_dim - elif isinstance(head_dim, tuple) and len(head_dim) == 2: - self.k_head_dim, self.v_head_dim = head_dim - else: - raise ValueError("head_dim must be an int or a tuple of two ints") - self.keep = keep - self.keys = None - self.values = None - self.offset = 0 - self.max_size = max_size - self.step = step - self._idx = 0 - - def _trim(self, trim_size, v, append=None): - to_cat = [] - if trim_size > 0: - to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] - else: - to_cat = [v] - if append is not None: - to_cat.append(append) - return mx.concatenate(to_cat, axis=2) - - def update_and_fetch(self, keys, values): - prev = self.offset - B, _, S = keys.shape[:3] - - # Prefill mode - if S > 1: - if self.keys is None: - self.keys = keys - self.values = values - else: - # The largest size is self.max_size + S - 1 to ensure - # every token gets at least self.max_size context - trim_size = self.keys.shape[2] - self.max_size + 1 - self.keys = self._trim(trim_size, self.keys, keys) - self.values = self._trim(trim_size, self.values, values) - self.offset += S - self._idx = self.keys.shape[2] - return self.keys, self.values - - # Generation mode - # May not have hit the max size yet, so potentially - # keep growing the cache - if self.keys is None or ( - prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size - ): - new_size = min(self.step, self.max_size - prev) - k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim) - v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - self._idx = prev - - # Trim if needed - trim_size = self.keys.shape[2] - self.max_size - if trim_size > 0: - self.keys = self._trim(trim_size, self.keys) - self.values = self._trim(trim_size, self.values) - self._idx = self.max_size - - # Rotate - if self._idx == self.max_size: - self._idx = self.keep - - # Assign - self.keys[..., self._idx : self._idx + 1, :] = keys - self.values[..., self._idx : self._idx + 1, :] = values - self.offset += 1 - self._idx += 1 - - # If the buffer is not full, slice off the end - if self.offset < self.max_size: - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - return self.keys, self.values - - @property - def state(self): - return self.keys, self.values @dataclass @@ -156,25 +20,30 @@ class BaseModelArgs: ) -def create_additive_causal_mask(N: int, offset: int = 0): +def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): rinds = mx.arange(offset + N) linds = mx.arange(offset, offset + N) if offset else rinds - mask = linds[:, None] < rinds[None] + linds = linds[:, None] + rinds = rinds[None] + mask = linds < rinds + if window_size is not None: + mask = mask | (linds > rinds + window_size) return mask * -1e9 def create_attention_mask(h: mx.array, cache: Optional[Any] = None): T = h.shape[1] if T > 1: + window_size = None + offset = 0 if cache is not None and cache[0] is not None: c = cache[0] - if isinstance(c, RotatingKVCache): + if hasattr(c, "max_size"): offset = min(c.max_size - 1, c.offset) + window_size = c.max_size else: offset = c.offset - else: - offset = 0 - 
mask = create_additive_causal_mask(T, offset)
+        mask = create_causal_mask(T, offset, window_size=window_size)
         mask = mask.astype(h.dtype)
     else:
         mask = None
diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
new file mode 100644
index 00000000..b06422e5
--- /dev/null
+++ b/llms/mlx_lm/models/cache.py
@@ -0,0 +1,333 @@
+# Copyright © 2023-2024 Apple Inc.
+
+from typing import Any, Dict, List, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+from mlx.utils import tree_flatten, tree_unflatten
+
+
+def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
+    """
+    Construct the model's cache for use when generating.
+
+    This function will defer the cache construction to the model if it has a
+    ``make_cache`` method, otherwise it will make a default KV cache.
+
+    Args:
+        model (nn.Module): The language model.
+        max_kv_size (Optional[int]): If provided and the model does not have a
+            ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
+            size of ``max_kv_size``
+    """
+    if hasattr(model, "make_cache"):
+        return model.make_cache()
+
+    num_layers = len(model.layers)
+    if max_kv_size is not None:
+        return [
+            RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
+        ]
+    else:
+        return [KVCache() for _ in range(num_layers)]
+
+
+def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
+    """
+    Save a pre-computed prompt cache to a file.
+
+    Args:
+        file_name (str): The ``.safetensors`` file name.
+        cache (List[Any]): The model state.
+        metadata (Dict[str, str]): Optional metadata to save along with model
+            state.
+    """
+    cache_data = [c.state for c in cache]
+    cache_info = [c.meta_state for c in cache]
+    cache_data = dict(tree_flatten(cache_data))
+    cache_classes = [type(c).__name__ for c in cache]
+    cache_metadata = [cache_info, metadata, cache_classes]
+    cache_metadata = dict(tree_flatten(cache_metadata))
+    mx.save_safetensors(file_name, cache_data, cache_metadata)
+
+
+def load_prompt_cache(file_name, return_metadata=False):
+    """
+    Load a prompt cache from a file.
+
+    Args:
+        file_name (str): The ``.safetensors`` file name.
+        return_metadata (bool): Whether or not to return metadata.
+            Default: ``False``.
+
+    Returns:
+        List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
+            the metadata if requested.
+    """
+    arrays, cache_metadata = mx.load(file_name, return_metadata=True)
+    arrays = tree_unflatten(list(arrays.items()))
+    cache_metadata = tree_unflatten(list(cache_metadata.items()))
+    info, metadata, classes = cache_metadata
+    cache = [globals()[c]() for c in classes]
+    for c, state, meta_state in zip(cache, arrays, info):
+        c.state = state
+        c.meta_state = meta_state
+    if return_metadata:
+        return cache, metadata
+    return cache
+
+
+def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
+    """
+    Trim the model's cache by the given number of tokens.
+
+    This function will trim the cache if possible (in-place) and return the
+    number of tokens that were trimmed.
+
+    Args:
+        cache (List[Any]): The model's cache.
+        num_tokens (int): The number of tokens to trim.
+
+    Returns:
+        (int): The number of tokens that were trimmed.
+ """ + if not all(c.is_trimmable() for c in cache) or len(cache) == 0: + return 0 + return [c.trim(num_tokens) for c in cache][0] + + +class _BaseCache: + @property + def state(self): + return [] + + @state.setter + def state(self, v): + if v is not None and v: + raise ValueError("This cache has no state but a state was set.") + + @property + def meta_state(self): + return "" + + @meta_state.setter + def meta_state(self, v): + if v is not None and v: + raise ValueError("This cache has no meta_state but a meta_state was set.") + + def is_trimmable(self): + return False + + +class KVCache(_BaseCache): + def __init__(self): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + + def update_and_fetch(self, keys, values): + prev = self.offset + if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: + B, n_kv_heads, _, k_head_dim = keys.shape + v_head_dim = values.shape[3] + n_steps = (self.step + keys.shape[2] - 1) // self.step + k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim) + v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + if prev % self.step != 0: + self.keys = self.keys[..., :prev, :] + self.values = self.values[..., :prev, :] + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + + self.offset += keys.shape[2] + self.keys[..., prev : self.offset, :] = keys + self.values[..., prev : self.offset, :] = values + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + + @property + def state(self): + if self.offset == self.keys.shape[2]: + return self.keys, self.values + else: + return ( + self.keys[..., : self.offset, :], + self.values[..., : self.offset, :], + ) + + @state.setter + def state(self, v): + self.keys, self.values = v + self.offset = self.keys.shape[2] + + def is_trimmable(self): + return True + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + return n + + +class RotatingKVCache(_BaseCache): + + def __init__(self, max_size=None, keep=0, step=256): + self.keep = keep + self.keys = None + self.values = None + self.offset = 0 + self.max_size = max_size + self.step = step + self._idx = 0 + + def _trim(self, trim_size, v, append=None): + to_cat = [] + if trim_size > 0: + to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] + else: + to_cat = [v] + if append is not None: + to_cat.append(append) + return mx.concatenate(to_cat, axis=2) + + def _temporal_order(self, v): + """ + Rearrange the cache into temporal order, slicing off the end if unused. 
+ """ + if self._idx == v.shape[2]: + return v + elif self._idx < self.offset: + return mx.concatenate( + [ + v[..., : self.keep, :], + v[..., self._idx :, :], + v[..., self.keep : self._idx, :], + ], + axis=2, + ) + else: + return v[..., : self._idx, :] + + def _update_concat(self, keys, values): + if self.keys is None: + self.keys = keys + self.values = values + else: + # Put the keys/values in temporal order to + # preserve context + self.keys = self._temporal_order(self.keys) + self.values = self._temporal_order(self.values) + + # The largest size is self.max_size + S - 1 to ensure + # every token gets at least self.max_size context + trim_size = self._idx - self.max_size + 1 + self.keys = self._trim(trim_size, self.keys, keys) + self.values = self._trim(trim_size, self.values, values) + self.offset += keys.shape[2] + self._idx = self.keys.shape[2] + return self.keys, self.values + + def _update_in_place(self, keys, values): + # May not have hit the max size yet, so potentially + # keep growing the cache + B, n_kv_heads, S, k_head_dim = keys.shape + prev = self.offset + if self.keys is None or ( + prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size + ): + v_head_dim = values.shape[3] + new_size = min(self.step, self.max_size - prev) + k_shape = (B, n_kv_heads, new_size, k_head_dim) + v_shape = (B, n_kv_heads, new_size, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + self._idx = prev + + # Trim if needed + trim_size = self.keys.shape[2] - self.max_size + if trim_size > 0: + self.keys = self._trim(trim_size, self.keys) + self.values = self._trim(trim_size, self.values) + self._idx = self.max_size + + # Rotate + if self._idx == self.max_size: + self._idx = self.keep + + # Assign + self.keys[..., self._idx : self._idx + S, :] = keys + self.values[..., self._idx : self._idx + S, :] = values + self.offset += S + self._idx += S + + # If the buffer is not full, slice off the end + if self.offset < self.max_size: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + return self.keys, self.values + + def update_and_fetch(self, keys, values): + if keys.shape[2] == 1: + return self._update_in_place(keys, values) + return self._update_concat(keys, values) + + @property + def state(self): + if self.offset < self.keys.shape[2]: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + else: + return self.keys, self.values + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple( + map(str, (self.keep, self.max_size, self.step, self.offset, self._idx)) + ) + + @meta_state.setter + def meta_state(self, v): + self.keep, self.max_size, self.step, self.offset, self._idx = map( + int, + v, + ) + + def is_trimmable(self): + return self.offset < self.max_size + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + self._idx -= n + return n + + +class MambaCache(_BaseCache): + def __init__(self): + self.cache = [None, None] + + def __setitem__(self, idx, value): + self.cache[idx] = value + + def __getitem__(self, idx): + return self.cache[idx] + + @property + def state(self): + return self.cache + + @state.setter + def state(self, v): + self.cache = v diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index 
cfcf2945..057c816d 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -69,7 +69,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -129,7 +129,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.input_layernorm(x) attn_h = self.self_attn(h, mask, cache) @@ -190,11 +190,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index f0214549..3b7e83d7 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -49,7 +49,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: qkv = self.Wqkv(x) @@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.attn(self.norm_1(x), mask=mask, cache=cache) x = h + x @@ -179,7 +179,7 @@ class DecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r, h = self.norm_attn_norm(x, mask, cache) out = self.ffn(h) + r @@ -249,11 +249,3 @@ class Model(nn.Module): experts = [(s, sv.T) for s, sv in experts] new_weights.update(experts) return new_weights - - @property - def head_dim(self): - return self.args.d_model // self.args.n_heads - - @property - def n_kv_heads(self): - return self.args.attn_config["kv_n_heads"] diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index dcfa331c..03cb3b1a 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import Dict, Optional +from typing import Any, Dict, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, _ = x.shape @@ -108,8 +108,8 @@ class DeepseekMLP(nn.Module): def __init__( self, config: ModelArgs, - hidden_size: int | None = None, - intermediate_size: int | None = None, + hidden_size: Optional[int] = None, + intermediate_size: Optional[int] = None, ): super().__init__() self.config = config @@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module): self, x: 
mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -210,7 +210,7 @@ class DeepseekModel(nn.Module): def __call__( self, x: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.embed_tokens(x) mask = create_attention_mask(h, cache) @@ -235,7 +235,7 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ): out = self.model(inputs, cache) return self.lm_head(out) @@ -256,11 +256,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 602a9710..17d061a8 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -2,12 +2,12 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs): max_position_embeddings: int = 2048 rms_norm_eps: float = 1e-6 rope_theta: float = 10000.0 - rope_scaling: Optional[Dict] = None + rope_scaling: Dict = None attention_bias: bool = False @@ -172,12 +172,11 @@ class DeepseekV2Attention(nn.Module): bias=config.attention_bias, ) - if self.config.rope_scaling is not None: - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.scale = self.scale * mscale * mscale + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scale = self.scale * mscale * mscale rope_kwargs = { key: self.config.rope_scaling[key] @@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module): def __call__( self, x: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.embed_tokens(x) mask = create_attention_mask(h, cache) @@ -395,7 +394,7 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ): out = self.model(inputs, cache) return self.lm_head(out) @@ -416,14 +415,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.qk_nope_head_dim + self.args.qk_rope_head_dim, - self.args.v_head_dim, - ) - - @property - def n_kv_heads(self): - return 
self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index c6150284..61de781e 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -60,7 +60,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -113,7 +113,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -173,11 +173,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index 1d410a15..ccc327a8 100644 --- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -64,7 +64,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) @@ -135,13 +135,11 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: - r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + self.post_attention_layernorm(r) - r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype( - mx.float32 - ) + r = self.mlp(self.pre_feedforward_layernorm(h)) out = h + self.post_feedforward_layernorm(r) return out @@ -200,11 +198,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 8a770936..97d9a8ff 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -46,7 +46,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -100,7 +100,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.ln_1(x), mask, cache) h = x + r @@ -196,11 +196,3 @@ class Model(nn.Module): @property def layers(self): return self.model.h - - @property - def head_dim(self): - return self.args.n_embd // self.args.n_head - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 652eb9e4..068046ea 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -57,7 +57,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -114,7 +114,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.ln_1(x), mask, cache) h = x + r @@ -184,11 +184,3 @@ class Model(nn.Module): @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.n_embd // self.args.n_head - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index c2aaa9ea..9f662491 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -60,7 +60,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -120,7 +120,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: residual = x # NeoX runs attention and feedforward network in parallel. @@ -214,11 +214,3 @@ class Model(nn.Module): @property def layers(self): return self.model.h - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index bcc0cf0c..5264cb57 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -116,7 +116,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -171,7 +171,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attention(self.attention_norm(x), mask, cache) h = x + r @@ -236,11 +236,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index c4a947a5..7da6b333 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -171,7 +171,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -233,7 +233,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -303,13 +303,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 26408426..d2740dc1 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -7,6 +7,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs +from .cache import MambaCache @dataclass @@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs): self.time_step_rank = math.ceil(self.hidden_size / 16) -class MambaCache: - def __init__(self): - self.cache = [None, None] - - def __setitem__(self, idx, value): - self.cache[idx] = value - - def __getitem__(self, idx): - return self.cache[idx] - - @property - def state(self): - return self.cache - - class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() @@ -223,7 +209,7 @@ class Model(nn.Module): weights[k] = v.moveaxis(2, 1) return weights - def make_cache(self, batch_size: int = 1): + def make_cache(self): return [MambaCache() for _ in range(len(self.layers))] @property diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index df0670be..4ac3c3b4 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -85,7 +85,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ): B, L, _ = x.shape @@ -135,7 +135,7 @@ class DecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers)) @@ -205,11 +205,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 2db57752..20944fe3 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -2,7 +2,7 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -66,7 +66,7 @@ class MixtralAttention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -215,11 +215,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index ef55d1d7..3ea06e27 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -2,12 +2,12 @@ from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -94,7 +94,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, _ = x.shape @@ -151,7 +151,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -215,13 +215,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 59849c96..3627df06 100644 --- 
a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -1,8 +1,8 @@ # Copyright © 2023-2024 Apple Inc. +import sys from dataclasses import dataclass -from sys import exit -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -13,7 +13,7 @@ try: import hf_olmo except ImportError: print("To run olmo install ai2-olmo: pip install ai2-olmo") - exit(1) + sys.exit(1) @dataclass @@ -68,7 +68,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -98,7 +98,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attend(self.att_norm(x), mask, cache) h = x + r @@ -174,11 +174,3 @@ class Model(nn.Module): @property def layers(self): return self.model.transformer.blocks - - @property - def head_dim(self): - return self.args.d_model // self.args.n_heads - - @property - def n_kv_heads(self): - return self.args.n_heads diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 19d3c027..090e21c6 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -80,7 +80,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -152,7 +152,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.attn_norm(x), mask, cache) h = x + r @@ -218,11 +218,3 @@ class Model(nn.Module): @property def layers(self): return self.transformer.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_kv_heads diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index fd3fd709..56b383b2 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -162,19 +162,11 @@ class Model(nn.Module): def __call__( self, x: mx.array, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: y = self.model(x, cache) return self.lm_head(y) @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 112ade7d..9ef76f04 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .su_rope import SuScaledRotaryEmbedding @@ -84,7 +84,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -143,7 +143,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -202,11 +202,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 665dbc73..6b0759b4 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -3,12 +3,12 @@ import math from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Tuple, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int layer_norm_epsilon: float vocab_size: int - num_key_value_heads: Optional[int] = None + num_key_value_heads: int mup_attn_multiplier: float = 1.0 mup_use_scaling: bool = True mup_embedding_multiplier: float = 10.0 mup_width_multiplier: float = 8.0 rope_embedding_base: float = 1000000 rope_position_scale: float = 1.0 - blocksparse_block_size: Tuple[int] = (64,) + blocksparse_block_size: int = 64 blocksparse_num_local_blocks: int = 16 blocksparse_vert_stride: int = 8 @@ -61,7 +61,6 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads - assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.n_q_per_kv = n_heads // n_kv_heads @@ -161,7 +160,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -230,7 +229,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -304,16 +303,8 @@ class Model(nn.Module): def layers(self): return self.model.layers - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - def sanitize(self, weights): # Remove unused precomputed rotary freqs return { k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k } - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index db6bd4b5..ca20a388 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module): class 
Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.model_type = args.model_type self.args = args self.model = PhiMoEModel(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) @@ -208,11 +209,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index bb67615d..865d0d8e 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -168,8 +168,8 @@ class Model(nn.Module): self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: mask = create_attention_mask(x, cache) y = self.transformer(x, mask, cache) @@ -193,11 +193,3 @@ class Model(nn.Module): @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.model_dim // self.args.num_heads - - @property - def n_kv_heads(self): - return self.args.num_heads diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index 5d2b7586..090922ae 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn @@ -62,8 +62,8 @@ class Attention(nn.Module): self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + cache: Optional[Any] = None, + ) -> mx.array: bsz, q_len, _ = hidden_states.shape queries = self.q_proj(hidden_states) @@ -127,8 +127,8 @@ class PlamoDecoderLayer(nn.Module): self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[Any, ...]: + cache: Optional[Any] = None, + ): # from LlamaDecoder residual = hidden_states @@ -169,8 +169,8 @@ class PlamoModel(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None, - ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]: + cache: Optional[Any] = None, + ) -> mx.array: h = self.embed_tokens(inputs) mask = create_attention_mask(h, cache) @@ -197,19 +197,11 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[List[Tuple[mx.array, mx.array]]] = None, - ) -> Tuple[mx.array, mx.array]: + cache: Optional[Any] = None, + ) -> mx.array: out = self.model(inputs, cache) return self.lm_head(out) @property def layers(self): return self.model.layers.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_attention_heads // self.args.n_shared_head diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 6d2c7bbf..2b69d5ec 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -1,7 +1,6 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Tuple import mlx.core as mx import mlx.nn as nn @@ -149,19 +148,11 @@ class Model(nn.Module): self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: y = self.transformer(x, mask, cache) return self.lm_head(y) @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_attention_heads diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index b3ce02a3..4e7858de 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -70,7 +70,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -124,7 +124,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -196,11 +196,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index ff7831f3..d199116f 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -2,12 +2,12 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -70,7 +70,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -236,11 +236,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 34750ace..06a307a6 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -7,13 +7,13 @@ from typing import List, Literal, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, create_attention_mask +from .cache import MambaCache, 
RotatingKVCache @dataclass class ModelArgs(BaseModelArgs): model_type: str - hidden_size: int attention_bias: bool conv1d_width: int hidden_size: int @@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs): self.block_types = self._block_types -def create_window_causal_mask(N: int, window_size: int): - inds = mx.arange(N) - linds = inds[:, None] - rinds = inds[None] - mask = (linds < rinds) | (linds > rinds + window_size) - return mask * -1e9 - - -class RecurrentCache: - - def __init__(self): - self._cache = (None, None) - - def __getitem__(self, idx): - return self._cache[idx] - - def update(self, conv_state, recurrent_state): - self._cache = (conv_state, recurrent_state) - - def state(self): - return self._cache - - -class WindowKVCache: - - def __init__(self, window_size): - self.keys = None - self.values = None - self.offset = 0 - self.window_size = window_size - - def update_and_fetch(self, keys, values): - # TODO consider using rotating buffer here - # especially for very long generations - def _update(x, v): - t = x.shape[2] - self.window_size - if t > 0: - x = x[..., t:, :] - return mx.concatenate([x, v], axis=2) - - self.offset += keys.shape[2] - if self.keys is None: - self.keys = keys - self.values = values - else: - self.keys = _update(self.keys, keys) - self.values = _update(self.values, values) - return self.keys, self.values - - def state(self): - return self.keys, self.values - - class RMSNorm(nn.Module): def __init__(self, dims: int, eps: float = 1e-5): super().__init__() @@ -136,31 +83,22 @@ class Conv1d(nn.Module): kernel_size: int, ): super().__init__() - self.weight = mx.zeros((kernel_size, channels)) + self.weight = mx.zeros((channels, kernel_size, 1)) self.bias = mx.zeros((channels,)) def __call__(self, x, cache=None): - w = self.weight.T[..., None] - kw, groups = self.weight.shape - if cache is not None: - l = [] - # Pad the cache if needed - if cache.shape[1] < kw - 1: - l.append( - mx.zeros( - (x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype - ) - ) - l.extend([cache, x]) - x = mx.concatenate(l, axis=1) - y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True) - else: - y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups) + B, L, C = x.shape + groups, K, _ = self.weight.shape - # The cache is always kw - 1 - cache = x[:, max(x.shape[1] - kw + 1, 0) :, :] + if cache is not None: + x = mx.concatenate([cache, x], axis=1) + else: + x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) + + y = mx.conv_general(x, self.weight, groups=groups) y = y + self.bias - return y, cache + + return y, x[:, -K + 1 :, :] class RGLRU(nn.Module): @@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module): # x branch. 
x = self.linear_x(x) if cache is None: - conv_state, recurrent_state = (None, None) - else: - conv_state, recurrent_state = cache[0], cache[1] - x, conv_state = self.conv_1d( - x=x, - cache=conv_state, - ) - x, recurrent_state = self.rg_lru( - x=x, - cache=recurrent_state, - ) - if cache is not None: - cache.update(conv_state, recurrent_state) + cache = [None, None] + x, cache[0] = self.conv_1d(x=x, cache=cache[0]) + x, cache[1] = self.rg_lru(x=x, cache=cache[1]) x = x * y x = self.linear_out(x) @@ -467,12 +395,14 @@ class Griffin(nn.Module): if self.scale_by_sqrt_dim: x = x * math.sqrt(x.shape[-1]) - mask = None - if x.shape[1] > 1: - mask = create_window_causal_mask( - x.shape[1], self.config.attention_window_size - ) - mask = mask.astype(x.dtype) + if cache is None: + cache = [None] * len(self.layers) + + for i, block in enumerate(self.layers): + if block.temporal_block_type != "recurrent": + mask_cache = [cache[i]] + + mask = create_attention_mask(x, mask_cache) for i, block in enumerate(self.layers): x = block(x, mask=mask, cache=cache[i]) @@ -485,6 +415,7 @@ class Model(nn.Module): def __init__(self, config): self.args = config self.model = Griffin(config) + self.model_type = config.model_type self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__(self, tokens: mx.array, cache=None) -> mx.array: @@ -508,10 +439,9 @@ class Model(nn.Module): return self.model.layers def sanitize(self, weights): - # Remove unused precomputed rotary freqs for k, v in weights.items(): if "conv_1d.weight" in k and v.ndim == 3: - weights[k] = v.squeeze(1).T + weights[k] = v.moveaxis(2, 1) if "lm_head.weight" not in weights: self.pop("lm_head") return weights @@ -520,7 +450,7 @@ class Model(nn.Module): cache = [] for layer in self.layers: if layer.temporal_block_type == "recurrent": - cache.append(RecurrentCache()) + cache.append(MambaCache()) else: - cache.append(WindowKVCache(self.args.attention_window_size)) + cache.append(RotatingKVCache(max_size=self.args.attention_window_size)) return cache diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index b340de28..11202b02 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -2,7 +2,6 @@ import math from dataclasses import dataclass -from typing import Tuple import mlx.core as mx import mlx.nn as nn @@ -198,8 +197,8 @@ class Model(nn.Module): self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: mask = create_attention_mask(x, cache) y = self.model(x, mask, cache) return self.lm_head(y) @@ -207,11 +206,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 9cec0e39..ce0a2ec5 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. 
from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -45,7 +45,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -100,7 +100,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -164,11 +164,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 159efb54..8b9fb27d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -18,7 +18,7 @@ from mlx.utils import tree_flatten from transformers import PreTrainedTokenizer # Local imports -from .models.base import KVCache, RotatingKVCache +from .models import base, cache from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model @@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float) return logits -def make_kv_caches( - model: nn.Module, max_kv_size: Optional[int] = None -) -> List[Union[KVCache, RotatingKVCache]]: - if hasattr(model, "make_cache"): - return model.make_cache() - - kv_heads = ( - [model.n_kv_heads] * len(model.layers) - if isinstance(model.n_kv_heads, int) - else model.n_kv_heads - ) - if max_kv_size is not None: - return [ - RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4) - for n in kv_heads - ] - else: - return [KVCache(model.head_dim, n) for n in kv_heads] - - def generate_step( prompts: mx.array, model: nn.Module, @@ -155,7 +135,7 @@ def generate_step( min_tokens_to_keep: int = 1, prefill_step_size: int = 512, max_kv_size: Optional[int] = None, - cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None, + prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: @@ -180,6 +160,8 @@ def generate_step( prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. logit_bias (dictionary, optional): Additive logit bias. 
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed @@ -243,20 +225,13 @@ def generate_step( tokens = None # Create the KV cache for generation - cache = make_kv_caches(model, max_kv_size) - - if cache_history is not None: - if len(cache_history) != len(cache): - raise ValueError("Wrong number of layers in the cache history") - - # Set the history in the cache objects and evaluate them to prepare for - # generation. - for c, h in zip(cache, cache_history): - c.update_and_fetch(h[0], h[1]) - mx.eval([c.state for c in cache]) + if prompt_cache is None: + prompt_cache = cache.make_prompt_cache(model, max_kv_size) + elif len(prompt_cache) != len(model.layers): + raise ValueError("Wrong number of layers in the prompt cache.") def _step(y): - logits = model(y, cache=cache) + logits = model(y, cache=prompt_cache) logits = logits[:, -1, :] if logits_processor: @@ -270,7 +245,7 @@ def generate_step( return y, logprobs while y.shape[1] > prefill_step_size: - model(y[:, :prefill_step_size], cache=cache) - mx.eval([c.state for c in cache]) + model(y[:, :prefill_step_size], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) y = y[:, prefill_step_size:] @@ -312,9 +287,9 @@ def stream_generate( detokenizer = tokenizer.detokenizer detokenizer.reset() - for (token, _), n in zip( - generate_step(prompt_tokens[None], model, **kwargs), + for _, (token, _) in zip( range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), ): token = token.item() if token == tokenizer.eos_token_id: @@ -365,9 +340,9 @@ def generate( tic = time.perf_counter() detokenizer.reset() - for (token, logprobs), n in zip( - generate_step(prompt_tokens[None], model, **kwargs), + for n, (token, logprobs) in zip( range(max_tokens), + generate_step(prompt_tokens[None], model, **kwargs), ): token = token.item() if n == 0: diff --git a/llms/setup.py b/llms/setup.py index e2cfe0cd..1c696dc0 100644 --- a/llms/setup.py +++ b/llms/setup.py @@ -32,6 +32,7 @@ setup( entry_points={ "console_scripts": [ "mlx_lm.cache_prompt = mlx_lm.cache_prompt:main", + "mlx_lm.chat = mlx_lm.chat:main", "mlx_lm.convert = mlx_lm.convert:main", "mlx_lm.fuse = mlx_lm.fuse:main", "mlx_lm.generate = mlx_lm.generate:main", diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index cd7e7fd0..1efde5ae 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -1,17 +1,15 @@ # Copyright © 2024 Apple Inc. 
- import unittest import mlx.core as mx from mlx.utils import tree_map -from mlx_lm.models.base import KVCache, RotatingKVCache -from mlx_lm.utils import make_kv_caches +from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache class TestModels(unittest.TestCase): def test_kv_cache(self): - cache = KVCache(32, 4) + cache = KVCache() k = mx.ones((1, 4, 1, 32), mx.float16) v = mx.ones((1, 4, 1, 32), mx.float16) @@ -32,7 +30,7 @@ class TestModels(unittest.TestCase): def test_rotating_kv_cache(self): b, h, d = 1, 2, 32 - cache = RotatingKVCache(d, h, max_size=8, step=4) + cache = RotatingKVCache(max_size=8, step=4) k = mx.random.uniform(shape=(b, h, 2, d)) v = mx.random.uniform(shape=(b, h, 2, d)) @@ -65,7 +63,7 @@ class TestModels(unittest.TestCase): idx %= 8 # Try with nonzero keep - cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2) + cache = RotatingKVCache(max_size=8, step=4, keep=2) # Check a large update k = mx.random.uniform(shape=(b, h, 20, d)) @@ -88,6 +86,46 @@ class TestModels(unittest.TestCase): if idx >= 8: idx = 2 + def test_rotating_kv_cache_chat_mode(self): + # Test that the rotating kv cache can handle + # alternating prompt/prefill with generation + d = 4 + h = 2 + cache = RotatingKVCache(max_size=18, step=4) + + x = mx.random.uniform(shape=(1, h, 8, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 8) + self.assertEqual(cache.offset, 8) + + x = mx.random.uniform(shape=(1, h, 1, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 9) + self.assertEqual(cache.offset, 9) + self.assertTrue(mx.allclose(x, k[..., 8:9, :])) + + x = mx.random.uniform(shape=(1, h, 2, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 11) + self.assertEqual(cache.offset, 11) + self.assertTrue(mx.allclose(x, k[..., 9:11, :])) + + x = mx.random.uniform(shape=(1, h, 3, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(k.shape[2], 14) + self.assertEqual(cache.offset, 14) + self.assertTrue(mx.allclose(x, k[..., 11:14, :])) + + x = mx.random.uniform(shape=(1, h, 6, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(cache.offset, 20) + self.assertTrue(mx.allclose(x, k[..., -6:, :])) + + x = mx.random.uniform(shape=(1, h, 2, d)) + k, v = cache.update_and_fetch(x, x) + self.assertEqual(cache.offset, 22) + self.assertTrue(mx.allclose(x, k[..., -2:, :])) + def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(len(model.layers), num_layers) @@ -101,7 +139,7 @@ class TestModels(unittest.TestCase): self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) - cache = make_kv_caches(model) + cache = make_prompt_cache(model) outputs = model(inputs, cache) self.assertEqual(outputs.shape, (1, 2, vocab_size)) self.assertEqual(outputs.dtype, t) @@ -549,6 +587,179 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_deepseek(self): + from mlx_lm.models import deepseek + + args = deepseek.ModelArgs( + model_type="deepseek", + vocab_size=1024, + hidden_size=128, + intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=4, + ) + model = deepseek.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_deepseek_v2(self): + from mlx_lm.models import deepseek_v2 + + args = deepseek_v2.ModelArgs( + model_type="deepseek_v2", + vocab_size=1024, + hidden_size=128, + 
intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + kv_lora_rank=4, + q_lora_rank=4, + qk_rope_head_dim=32, + v_head_dim=16, + qk_nope_head_dim=32, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + ) + model = deepseek_v2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_gemma2(self): + from mlx_lm.models import gemma2 + + args = gemma2.ModelArgs( + model_type="gemma2", + hidden_size=128, + num_hidden_layers=4, + intermediate_size=256, + num_attention_heads=2, + head_dim=32, + rms_norm_eps=1e-4, + vocab_size=1024, + num_key_value_heads=2, + ) + model = gemma2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_gpt_bigcode(self): + from mlx_lm.models import gpt_bigcode + + args = gpt_bigcode.ModelArgs( + model_type="gpt_bigcode", + n_embd=128, + n_layer=128, + n_inner=256, + n_head=4, + n_positions=1000, + layer_norm_epsilon=1e-5, + vocab_size=1024, + ) + model = gpt_bigcode.Model(args) + self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer) + + def test_nemotron(self): + from mlx_lm.models import nemotron + + args = nemotron.ModelArgs( + model_type="nemotron", + hidden_size=128, + hidden_act="gelu", + num_hidden_layers=4, + intermediate_size=256, + num_attention_heads=4, + norm_eps=1e-5, + vocab_size=1024, + num_key_value_heads=2, + ) + model = nemotron.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_phi3small(self): + from mlx_lm.models import phi3small + + args = phi3small.ModelArgs( + model_type="phi3small", + hidden_size=128, + dense_attention_every_n_layers=2, + ff_intermediate_size=256, + gegelu_limit=1.0, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + layer_norm_epsilon=1e-4, + vocab_size=1000, + ) + model = phi3small.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_phimoe(self): + from mlx_lm.models import phimoe + + args = phimoe.ModelArgs( + model_type="phimoe", + vocab_size=320, + hidden_size=128, + intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + rope_scaling={ + "long_factor": [1.0] * 16, + "long_mscale": 1.243163121016122, + "original_max_position_embeddings": 4096, + "short_factor": [1.0] * 16, + "short_mscale": 1.243163121016122, + "type": "longrope", + }, + ) + model = phimoe.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_recurrent_gemma(self): + from mlx_lm.models import recurrent_gemma + + args = recurrent_gemma.ModelArgs( + model_type="recurrent_gemma", + hidden_size=128, + attention_bias=False, + conv1d_width=3, + intermediate_size=256, + logits_soft_cap=1.0, + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + attention_window_size=1024, + vocab_size=1000, + block_types=["recurrent", "recurrent", "attention"], + ) + model = recurrent_gemma.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_prompt_cache.py 
b/llms/tests/test_prompt_cache.py new file mode 100644 index 00000000..3c1ef49b --- /dev/null +++ b/llms/tests/test_prompt_cache.py @@ -0,0 +1,220 @@ +# Copyright © 2024 Apple Inc. + +import os +import tempfile +import unittest + +import mlx.core as mx +from mlx_lm.models.cache import ( + KVCache, + MambaCache, + RotatingKVCache, + load_prompt_cache, + make_prompt_cache, + save_prompt_cache, + trim_prompt_cache, +) +from mlx_lm.utils import generate_step, load + +HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" + + +class TestPromptCache(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.test_dir_fid = tempfile.TemporaryDirectory() + cls.test_dir = cls.test_dir_fid.name + + @classmethod + def tearDownClass(cls): + cls.test_dir_fid.cleanup() + + def test_save_load(self): + cache = [KVCache() for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(c.state[0], lc.state[0])) + self.assertTrue(mx.array_equal(c.state[1], lc.state[1])) + + # Test with metadata + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + metadata = {"a": "b", "c": "d"} + save_prompt_cache(cache_file, cache, metadata) + _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) + self.assertEqual(metadata, loaded_metadata) + + def test_save_load_rotating_cache(self): + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + + # Test with rotating cache + cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + self.assertEqual(c.keep, lc.keep) + self.assertEqual(c.max_size, lc.max_size) + self.assertEqual(c.step, lc.step) + self.assertTrue(mx.array_equal(c.state[0], lc.state[0])) + self.assertTrue(mx.array_equal(c.state[1], lc.state[1])) + + # Do a couple single token updates to get a rotation + for _ in range(2): + for c in cache: + x = mx.random.uniform(shape=(1, 8, 1, 4)) + c.update_and_fetch(x, x) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + + for c, lc in zip(cache, loaded_cache): + x = mx.random.uniform(shape=(1, 8, 1, 4)) + k, v = c.update_and_fetch(x, x) + lk, lv = lc.update_and_fetch(x, x) + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(k, lk)) + self.assertTrue(mx.array_equal(v, lv)) + + def test_save_load_mixed_cache(self): + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + + cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()] + for c in cache: + if isinstance(c, MambaCache): + c[0] = mx.random.uniform(shape=(4, 4, 4)) + c[1] = mx.random.uniform(shape=(4, 4, 4)) + else: + x = mx.random.uniform(shape=(4, 4, 7, 4)) + y = mx.random.uniform(shape=(4, 4, 7, 4)) + c.update_and_fetch(x, y) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + for c, lc in zip(cache, loaded_cache): + if isinstance(c, MambaCache): + 
self.assertTrue(mx.array_equal(c[0], lc[0])) + self.assertTrue(mx.array_equal(c[1], lc[1])) + else: + x = mx.random.uniform(shape=(4, 4, 1, 4)) + y = mx.random.uniform(shape=(4, 4, 1, 4)) + k, v = c.update_and_fetch(x, y) + lk, lv = lc.update_and_fetch(x, y) + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(k, lk)) + self.assertTrue(mx.array_equal(v, lv)) + + def test_cache_with_generate(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + results = zip(range(4), generate_step(prompt, model)) + toks, all_logits = zip(*(r[1] for r in results)) + + prompt_cache = make_prompt_cache(model) + i = 0 + for _, (tok, logits) in zip( + range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + ): + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + i += 1 + + for _, (tok, logits) in zip( + range(1), + generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + ): + i += 1 + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + + def test_trim_cache(self): + cache = [KVCache() for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + # Trim + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 7) + + # Trim more tokens than remain + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 3) + + # Can't trim mamba cache + cache = [MambaCache() for _ in range(2)] + for c in cache: + c.state = mx.zeros((5, 5)) + num_trimmed = trim_prompt_cache(cache, 7) + self.assertEqual(num_trimmed, 0) + + # All cache's have to be trimmable + cache = [MambaCache(), KVCache()] + cache[0].state = mx.zeros((5, 5)) + x = mx.random.uniform(shape=(1, 8, 10, 4)) + cache[1].update_and_fetch(x, x) + num_trimmed = trim_prompt_cache(cache, 1) + self.assertEqual(num_trimmed, 0) + + cache = [RotatingKVCache(max_size=6) for _ in range(2)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 5, 4)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 4) + + # Can't trim fixed-size KV cache after processing + # more than max_kv_size tokens + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + num_trimmed = trim_prompt_cache(cache, 4) + self.assertEqual(num_trimmed, 0) + + def test_trim_cache_with_generate(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + + prompt_cache = make_prompt_cache(model) + + # Generate one token so we process the full prompt + last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache)) + last_tok = mx.array([last_tok]) + + # Generate two more tokens + results = zip( + range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) + ) + toks, all_logits = zip(*(r[1] for r in results)) + + # To get back to the cache just after processing the prompt, + # trim by 3 tokens + trim_prompt_cache(prompt_cache, 3) + + # Generate the same thing again + results = zip( + range(2), generate_step(last_tok, model, prompt_cache=prompt_cache) + ) + second_toks, second_all_logits = zip(*(r[1] for r in results)) + self.assertEqual(toks, second_toks) + self.assertTrue( + all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits)) + ) + + +if __name__ == "__main__": + unittest.main()
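
Usage note (not part of the patch): below is a minimal sketch of the prompt-cache API this diff introduces — `make_prompt_cache`, the `prompt_cache` argument to `generate_step`, and `save_prompt_cache`/`load_prompt_cache` — following the patterns in `llms/tests/test_prompt_cache.py`. The model path is the one used by those tests; the token counts, cache file name, and metadata values are illustrative assumptions.

```python
# Sketch only: exercises the prompt-cache API added in this diff, mirroring
# llms/tests/test_prompt_cache.py. Values not taken from the patch (token
# counts, file name, metadata) are placeholders.
import mlx.core as mx

from mlx_lm.models.cache import (
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
)
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]

# One cache object per layer (KVCache, RotatingKVCache, or MambaCache,
# depending on the model); generate_step updates it in place.
prompt_cache = make_prompt_cache(model)

tokens = []
for _, (token, _) in zip(
    range(16), generate_step(prompt, model, prompt_cache=prompt_cache)
):
    tokens.append(token)

# Persist the processed prompt (plus the tokens generated so far) ...
cache_file = "prompt_cache.safetensors"
save_prompt_cache(cache_file, prompt_cache, {"note": "example"})

# ... and reload it later to keep generating without re-running the prompt.
restored_cache = load_prompt_cache(cache_file)
last_tok = mx.array([tokens[-1]])
for _, (token, _) in zip(
    range(16), generate_step(last_tok, model, prompt_cache=restored_cache)
):
    tokens.append(token)
```

`trim_prompt_cache` (also added above) can roll a cache back by a number of tokens when only a prefix of the cached context should be reused, as exercised in `test_trim_cache_with_generate`.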