From 782f5a71b787381a843b91375a67c286ced6013a Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sat, 5 Oct 2024 14:49:39 -0700 Subject: [PATCH] reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat --- .gitignore | 3 + llms/README.md | 4 +- llms/mlx_lm/cache_prompt.py | 19 +- llms/mlx_lm/examples/chat.py | 50 +++++ llms/mlx_lm/examples/generate_response.py | 2 + llms/mlx_lm/generate.py | 57 ++--- llms/mlx_lm/models/base.py | 163 +------------- llms/mlx_lm/models/cache.py | 257 ++++++++++++++++++++++ llms/mlx_lm/models/cohere.py | 14 +- llms/mlx_lm/models/dbrx.py | 16 +- llms/mlx_lm/models/deepseek.py | 20 +- llms/mlx_lm/models/deepseek_v2.py | 36 +-- llms/mlx_lm/models/gemma.py | 14 +- llms/mlx_lm/models/gemma2.py | 20 +- llms/mlx_lm/models/gpt2.py | 14 +- llms/mlx_lm/models/gpt_bigcode.py | 14 +- llms/mlx_lm/models/gpt_neox.py | 14 +- llms/mlx_lm/models/internlm2.py | 14 +- llms/mlx_lm/models/llama.py | 18 +- llms/mlx_lm/models/mamba.py | 18 +- llms/mlx_lm/models/minicpm.py | 14 +- llms/mlx_lm/models/mixtral.py | 14 +- llms/mlx_lm/models/nemotron.py | 18 +- llms/mlx_lm/models/olmo.py | 18 +- llms/mlx_lm/models/openelm.py | 14 +- llms/mlx_lm/models/phi.py | 12 +- llms/mlx_lm/models/phi3.py | 16 +- llms/mlx_lm/models/phi3small.py | 21 +- llms/mlx_lm/models/phimoe.py | 9 +- llms/mlx_lm/models/phixtral.py | 12 +- llms/mlx_lm/models/plamo.py | 26 +-- llms/mlx_lm/models/qwen.py | 13 +- llms/mlx_lm/models/qwen2.py | 16 +- llms/mlx_lm/models/qwen2_moe.py | 16 +- llms/mlx_lm/models/recurrent_gemma.py | 122 ++-------- llms/mlx_lm/models/stablelm.py | 13 +- llms/mlx_lm/models/starcoder2.py | 16 +- llms/mlx_lm/utils.py | 53 ++--- llms/tests/test_models.py | 182 ++++++++++++++- llms/tests/test_prompt_cache.py | 143 ++++++++++++ 40 files changed, 824 insertions(+), 691 deletions(-) create mode 100644 llms/mlx_lm/examples/chat.py create mode 100644 llms/mlx_lm/models/cache.py create mode 100644 llms/tests/test_prompt_cache.py diff --git a/.gitignore b/.gitignore index f3dfe929..45445fc8 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ __pycache__/ # C extensions *.so +# Vim +*.swp + # Distribution / packaging .Python build/ diff --git a/llms/README.md b/llms/README.md index 75677865..5c59c796 100644 --- a/llms/README.md +++ b/llms/README.md @@ -155,14 +155,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example: cat prompt.txt | mlx_lm.cache_prompt \ --model mistralai/Mistral-7B-Instruct-v0.3 \ --prompt - \ - --kv-cache-file mistral_prompt.safetensors + --prompt-cache-file mistral_prompt.safetensors ``` Then use the cached prompt with `mlx_lm.generate`: ``` mlx_lm.generate \ - --kv-cache-file mistral_prompt.safetensors \ + --prompt-cache-file mistral_prompt.safetensors \ --prompt "\nSummarize the above text." ``` diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py index 9829efb4..04e75a3e 100644 --- a/llms/mlx_lm/cache_prompt.py +++ b/llms/mlx_lm/cache_prompt.py @@ -7,13 +7,14 @@ import time import mlx.core as mx -from .utils import load, make_kv_caches +from .models.cache import make_prompt_cache, save_prompt_cache +from .utils import load def setup_arg_parser(): """Set up and return the argument parser.""" parser = argparse.ArgumentParser( - description="Cache the KV cache of a prompt to be reused with mlx_lm.generate" + description="Cache the state of a prompt to be reused with mlx_lm.generate" ) parser.add_argument( "--model", @@ -60,7 +61,9 @@ def setup_arg_parser(): help="Set the maximum key-value cache size", ) parser.add_argument( - "--kv-cache-file", help="The file to save the KV caches in", required=True + "--prompt-cache-file", + help="The file to save the prompt cache in", + required=True, ) parser.add_argument( "--prompt", @@ -115,7 +118,7 @@ def main(): else: prompt = args.prompt - cache = make_kv_caches(model, args.max_kv_size) + cache = make_prompt_cache(model, args.max_kv_size) y = mx.array(tokenizer.encode(prompt)) # Process the prompt @@ -137,16 +140,12 @@ def main(): print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") print("Saving...") - cache_dict = {} - for i, c in enumerate(cache): - cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :] - cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :] metadata = {} metadata["model"] = args.model metadata["chat_template"] = tokenizer.chat_template metadata["tokenizer_config"] = json.dumps(tokenizer_config) - metadata["max_kv_size"] = str(args.max_kv_size) - mx.save_safetensors(args.kv_cache_file, cache_dict, metadata) + print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB") + save_prompt_cache(args.prompt_cache_file, cache, metadata) if __name__ == "__main__": diff --git a/llms/mlx_lm/examples/chat.py b/llms/mlx_lm/examples/chat.py new file mode 100644 index 00000000..d4fc5bca --- /dev/null +++ b/llms/mlx_lm/examples/chat.py @@ -0,0 +1,50 @@ +# Copyright © 2024 Apple Inc. + +""" +An example of a multi-turn chat with prompt caching. +""" + +from mlx_lm import generate, load +from mlx_lm.models.cache import make_prompt_cache + +model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit") + +# Make the initial prompt cache for the model +prompt_cache = make_prompt_cache(model) + +# User turn +prompt = "Hi my name is ." +messages = [{"role": "user", "content": prompt}] +prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) + +# Assistant response +response = generate( + model, + tokenizer, + prompt=prompt, + verbose=True, + max_tokens=1024, + temp=0.0, + prompt_cache=prompt_cache, +) +messages.append({"role": "assistant", "content": response}) + +# User turn +prompt = "What's my name?" +messages = [{"role": "user", "content": prompt}] +prompt = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) + +# Assistant response +response = generate( + model, + tokenizer, + prompt=prompt, + verbose=True, + max_tokens=1024, + temp=0.0, + prompt_cache=prompt_cache, +) diff --git a/llms/mlx_lm/examples/generate_response.py b/llms/mlx_lm/examples/generate_response.py index af599c1b..25730617 100644 --- a/llms/mlx_lm/examples/generate_response.py +++ b/llms/mlx_lm/examples/generate_response.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. + from mlx_lm import generate, load # Specify the checkpoint diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 537bd853..82c53d25 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -6,6 +6,7 @@ import sys import mlx.core as mx +from .models.cache import load_prompt_cache from .utils import generate, load DEFAULT_PROMPT = "hello" @@ -96,7 +97,7 @@ def setup_arg_parser(): default=None, ) parser.add_argument( - "--kv-cache-file", + "--prompt-cache-file", type=str, default=None, help="A file containing saved KV caches to avoid recomputing them", @@ -131,24 +132,6 @@ def colorprint_by_t0(s, t0): colorprint(color, s) -def load_kv_cache_from_file(kv_cache_file): - if kv_cache_file is None: - return None, None - - kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True) - cache_per_layer = {} - for k, x in kv_cache.items(): - layer, kv_type = k.split("_") - if layer not in cache_per_layer: - cache_per_layer[layer] = {} - cache_per_layer[layer][kv_type] = x - - cache_history = [None] * len(cache_per_layer) - for layer, c in cache_per_layer.items(): - cache_history[int(layer)] = (c["keys"], c["values"]) - return cache_history, metadata - - def main(): parser = setup_arg_parser() args = parser.parse_args() @@ -158,22 +141,32 @@ def main(): if args.cache_limit_gb is not None: mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Load the kv cache and metadata if a kv cache file is provided - cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file) + # Load the prompt cache and metadata if a cache file is provided + using_cache = args.prompt_cache_file is not None + if using_cache: + prompt_cache, metadata = load_prompt_cache( + args.prompt_cache_file, return_metadata=True + ) # Building tokenizer_config tokenizer_config = ( - {} if cache_history is None else json.loads(metadata["tokenizer_config"]) + {} if not using_cache else json.loads(metadata["tokenizer_config"]) ) if args.trust_remote_code: tokenizer_config["trust_remote_code"] = True if args.eos_token is not None: tokenizer_config["eos_token"] = args.eos_token - # If no model path is provided then use the one in the kv cache history model_path = args.model - if cache_history is not None and model_path is None: - model_path = metadata["model"] + if using_cache: + if model_path is None: + model_path = metadata["model"] + elif model_path != metadata["model"]: + raise ValueError( + f"Providing a different model ({model_path}) than that " + f"used to create the prompt cache ({metadata['model']}) " + "is an error." + ) model, tokenizer = load( model_path, @@ -184,7 +177,7 @@ def main(): if args.use_default_chat_template: if tokenizer.chat_template is None: tokenizer.chat_template = tokenizer.default_chat_template - elif cache_history is not None: + elif using_cache: tokenizer.chat_template = metadata["chat_template"] if not args.ignore_chat_template and ( @@ -203,7 +196,7 @@ def main(): # Treat the prompt as a suffix assuming that the prefix is in the # stored kv cache. - if cache_history is not None: + if using_cache: test_prompt = tokenizer.apply_chat_template( [{"role": "user", "content": ""}], tokenize=False, @@ -217,12 +210,6 @@ def main(): raise ValueError("Cannot use --colorize with --verbose=False") formatter = colorprint_by_t0 if args.colorize else None - # Determine the max kv size from the kv cache or passed arguments - max_kv_size = args.max_kv_size - if cache_history is not None: - max_kv_size = metadata["max_kv_size"] - max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None - response = generate( model, tokenizer, @@ -232,8 +219,8 @@ def main(): formatter=formatter, temp=args.temp, top_p=args.top_p, - max_kv_size=max_kv_size, - cache_history=cache_history, + max_kv_size=args.max_kv_size, + prompt_cache=prompt_cache if using_cache else None, ) if not args.verbose: print(response) diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py index 75f19642..3628a808 100644 --- a/llms/mlx_lm/models/base.py +++ b/llms/mlx_lm/models/base.py @@ -2,153 +2,9 @@ import inspect from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Optional import mlx.core as mx -import mlx.nn as nn - - -class KVCache: - - def __init__(self, head_dim, n_kv_heads): - self.n_kv_heads = n_kv_heads - if isinstance(head_dim, int): - self.k_head_dim = self.v_head_dim = head_dim - elif isinstance(head_dim, tuple) and len(head_dim) == 2: - self.k_head_dim, self.v_head_dim = head_dim - else: - raise ValueError("head_dim must be an int or a tuple of two ints") - self.keys = None - self.values = None - self.offset = 0 - self.step = 256 - - def update_and_fetch(self, keys, values): - prev = self.offset - if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: - B = keys.shape[0] - n_steps = (self.step + keys.shape[2] - 1) // self.step - k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim) - v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - if prev % self.step != 0: - self.keys = self.keys[..., :prev, :] - self.values = self.values[..., :prev, :] - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - - self.offset += keys.shape[2] - self.keys[..., prev : self.offset, :] = keys - self.values[..., prev : self.offset, :] = values - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - - @property - def state(self): - return self.keys, self.values - - -class RotatingKVCache: - - def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256): - self.n_kv_heads = n_kv_heads - if isinstance(head_dim, int): - self.k_head_dim = self.v_head_dim = head_dim - elif isinstance(head_dim, tuple) and len(head_dim) == 2: - self.k_head_dim, self.v_head_dim = head_dim - else: - raise ValueError("head_dim must be an int or a tuple of two ints") - self.keep = keep - self.keys = None - self.values = None - self.offset = 0 - self.max_size = max_size - self.step = step - self._idx = 0 - - def _trim(self, trim_size, v, append=None): - to_cat = [] - if trim_size > 0: - to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] - else: - to_cat = [v] - if append is not None: - to_cat.append(append) - return mx.concatenate(to_cat, axis=2) - - def _update_concat(self, keys, values): - if self.keys is None: - self.keys = keys - self.values = values - else: - if self._idx < self.keys.shape[2]: - self.keys = self.keys[..., : self._idx, :] - self.values = self.values[..., : self._idx, :] - - # The largest size is self.max_size + S - 1 to ensure - # every token gets at least self.max_size context - trim_size = self._idx - self.max_size + 1 - self.keys = self._trim(trim_size, self.keys, keys) - self.values = self._trim(trim_size, self.values, values) - self.offset += keys.shape[2] - self._idx = self.keys.shape[2] - return self.keys, self.values - - def _update_in_place(self, keys, values): - # May not have hit the max size yet, so potentially - # keep growing the cache - B, _, S = keys.shape[:3] - prev = self.offset - if self.keys is None or ( - prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size - ): - new_size = min(self.step, self.max_size - prev) - k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim) - v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim) - new_k = mx.zeros(k_shape, keys.dtype) - new_v = mx.zeros(v_shape, values.dtype) - if self.keys is not None: - self.keys = mx.concatenate([self.keys, new_k], axis=2) - self.values = mx.concatenate([self.values, new_v], axis=2) - else: - self.keys, self.values = new_k, new_v - self._idx = prev - - # Trim if needed - trim_size = self.keys.shape[2] - self.max_size - if trim_size > 0: - self.keys = self._trim(trim_size, self.keys) - self.values = self._trim(trim_size, self.values) - self._idx = self.max_size - - # Rotate - if self._idx == self.max_size: - self._idx = self.keep - - # Assign - self.keys[..., self._idx : self._idx + S, :] = keys - self.values[..., self._idx : self._idx + S, :] = values - self.offset += S - self._idx += S - - # If the buffer is not full, slice off the end - if self.offset < self.max_size: - return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] - return self.keys, self.values - - def update_and_fetch(self, keys, values): - S = keys.shape[2] - if S == 1 or (self.keys is not None and S < (self.keys.shape[2] - self._idx)): - return self._update_in_place(keys, values) - - return self._update_concat(keys, values) - - @property - def state(self): - return self.keys, self.values @dataclass @@ -164,25 +20,30 @@ class BaseModelArgs: ) -def create_additive_causal_mask(N: int, offset: int = 0): +def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): rinds = mx.arange(offset + N) linds = mx.arange(offset, offset + N) if offset else rinds - mask = linds[:, None] < rinds[None] + linds = linds[:, None] + rinds = rinds[None] + mask = linds < rinds + if window_size is not None: + mask = mask | (linds > rinds + window_size) return mask * -1e9 def create_attention_mask(h: mx.array, cache: Optional[Any] = None): T = h.shape[1] if T > 1: + window_size = None + offset = 0 if cache is not None and cache[0] is not None: c = cache[0] - if isinstance(c, RotatingKVCache): + if hasattr(c, "max_size"): offset = min(c.max_size - 1, c.offset) + window_size = c.max_size else: offset = c.offset - else: - offset = 0 - mask = create_additive_causal_mask(T, offset) + mask = create_causal_mask(T, offset, window_size=window_size) mask = mask.astype(h.dtype) else: mask = None diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py new file mode 100644 index 00000000..7d703ee3 --- /dev/null +++ b/llms/mlx_lm/models/cache.py @@ -0,0 +1,257 @@ +# Copyright © 2023-2024 Apple Inc. + +from typing import Any, Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_flatten, tree_unflatten + + +def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]: + if hasattr(model, "make_cache"): + return model.make_cache() + + num_layers = len(model.layers) + if max_kv_size is not None: + return [ + RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers) + ] + else: + return [KVCache() for _ in range(num_layers)] + + +def save_prompt_cache( + file_name: str, cache: List[Any], metadata: Optional[Dict[str, str]] = None +): + """ + Save a pre-computed prompt cache to a file. + """ + cache_data, cache_info = zip(*(c.state for c in cache)) + cache_data = dict(tree_flatten(cache_data)) + cache_classes = [type(c).__name__ for c in cache] + cache_metadata = [cache_classes, cache_info] + if metadata: + cache_metadata.append(metadata) + cache_metadata = dict(tree_flatten(cache_metadata)) + mx.save_safetensors(file_name, cache_data, cache_metadata) + + +def load_prompt_cache(file_name, return_metadata=False): + """ + Load a prompt cache from a file. + Args: + file_name (str): The ``.safetensors`` file name. + return_metadata (bool): Whether or not to return metadata. Default: + ``False``. + + Returns: + List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and + the metadata if requested. + """ + arrays, cache_metadata = mx.load(file_name, return_metadata=True) + arrays = tree_unflatten(list(arrays.items())) + cache_metadata = tree_unflatten(list(cache_metadata.items())) + classes, info = cache_metadata[:2] + cache = [globals()[c]() for c in classes] + for c, *state in zip(cache, arrays, info): + c.state = state + if return_metadata: + return cache, cache_metadata[2] + return cache + + +class KVCache: + + def __init__(self): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + + def update_and_fetch(self, keys, values): + prev = self.offset + if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]: + B, n_kv_heads, _, k_head_dim = keys.shape + v_head_dim = values.shape[3] + n_steps = (self.step + keys.shape[2] - 1) // self.step + k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim) + v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + if prev % self.step != 0: + self.keys = self.keys[..., :prev, :] + self.values = self.values[..., :prev, :] + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + + self.offset += keys.shape[2] + self.keys[..., prev : self.offset, :] = keys + self.values[..., prev : self.offset, :] = values + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + + @property + def state(self): + if self.offset == self.keys.shape[2]: + return (self.keys, self.values), "" + else: + return ( + self.keys[..., : self.offset, :], + self.values[..., : self.offset, :], + ), "" + + @state.setter + def state(self, v): + self.keys, self.values = v[0] + self.offset = self.keys.shape[2] + + +class RotatingKVCache: + + def __init__(self, max_size=None, keep=0, step=256): + self.keep = keep + self.keys = None + self.values = None + self.offset = 0 + self.max_size = max_size + self.step = step + self._idx = 0 + + def _trim(self, trim_size, v, append=None): + to_cat = [] + if trim_size > 0: + to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]] + else: + to_cat = [v] + if append is not None: + to_cat.append(append) + return mx.concatenate(to_cat, axis=2) + + def _temporal_order(self, v): + """ + Rearrange the cache into temporal order, slicing off the end if unused. + """ + if self._idx == v.shape[2]: + return v + elif self._idx < self.offset: + return mx.concatenate( + [ + v[..., : self.keep, :], + v[..., self._idx :, :], + v[..., self.keep : self._idx, :], + ], + axis=2, + ) + else: + return v[..., : self._idx, :] + + def _update_concat(self, keys, values): + if self.keys is None: + self.keys = keys + self.values = values + else: + # Put the keys/values in temporal order to + # preserve context + self.keys = self._temporal_order(self.keys) + self.values = self._temporal_order(self.values) + + # The largest size is self.max_size + S - 1 to ensure + # every token gets at least self.max_size context + trim_size = self._idx - self.max_size + 1 + self.keys = self._trim(trim_size, self.keys, keys) + self.values = self._trim(trim_size, self.values, values) + self.offset += keys.shape[2] + self._idx = self.keys.shape[2] + return self.keys, self.values + + def _update_in_place(self, keys, values): + # May not have hit the max size yet, so potentially + # keep growing the cache + B, n_kv_heads, S, k_head_dim = keys.shape + prev = self.offset + if self.keys is None or ( + prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size + ): + v_head_dim = values.shape[3] + new_size = min(self.step, self.max_size - prev) + k_shape = (B, n_kv_heads, new_size, k_head_dim) + v_shape = (B, n_kv_heads, new_size, v_head_dim) + new_k = mx.zeros(k_shape, keys.dtype) + new_v = mx.zeros(v_shape, values.dtype) + if self.keys is not None: + self.keys = mx.concatenate([self.keys, new_k], axis=2) + self.values = mx.concatenate([self.values, new_v], axis=2) + else: + self.keys, self.values = new_k, new_v + self._idx = prev + + # Trim if needed + trim_size = self.keys.shape[2] - self.max_size + if trim_size > 0: + self.keys = self._trim(trim_size, self.keys) + self.values = self._trim(trim_size, self.values) + self._idx = self.max_size + + # Rotate + if self._idx == self.max_size: + self._idx = self.keep + + # Assign + self.keys[..., self._idx : self._idx + S, :] = keys + self.values[..., self._idx : self._idx + S, :] = values + self.offset += S + self._idx += S + + # If the buffer is not full, slice off the end + if self.offset < self.max_size: + return self.keys[..., : self.offset, :], self.values[..., : self.offset, :] + return self.keys, self.values + + def update_and_fetch(self, keys, values): + if keys.shape[2] == 1: + return self._update_in_place(keys, values) + return self._update_concat(keys, values) + + @property + def state(self): + if self.offset < self.keys.shape[2]: + kv_state = (self.keys[..., : self.offset], self.values[..., : self.offset]) + else: + kv_state = (self.keys, self.values) + extra_state = tuple( + map(str, (self.keep, self.max_size, self.step, self.offset, self._idx)) + ) + return kv_state, extra_state + + @state.setter + def state(self, v): + self.keys, self.values = v[0] + self.keep, self.max_size, self.step, self.offset, self._idx = map( + int, + v[1], + ) + + +class MambaCache: + def __init__(self): + self.cache = [None, None] + + def __setitem__(self, idx, value): + self.cache[idx] = value + + def __getitem__(self, idx): + return self.cache[idx] + + @property + def state(self): + return self.cache + + @property + def state(self): + return self.cache, "" + + @state.setter + def state(self, v): + self.cache = v[0] diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py index cfcf2945..057c816d 100644 --- a/llms/mlx_lm/models/cohere.py +++ b/llms/mlx_lm/models/cohere.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -69,7 +69,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -129,7 +129,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.input_layernorm(x) attn_h = self.self_attn(h, mask, cache) @@ -190,11 +190,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py index f0214549..3b7e83d7 100644 --- a/llms/mlx_lm/models/dbrx.py +++ b/llms/mlx_lm/models/dbrx.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -49,7 +49,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: qkv = self.Wqkv(x) @@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.attn(self.norm_1(x), mask=mask, cache=cache) x = h + x @@ -179,7 +179,7 @@ class DecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r, h = self.norm_attn_norm(x, mask, cache) out = self.ffn(h) + r @@ -249,11 +249,3 @@ class Model(nn.Module): experts = [(s, sv.T) for s, sv in experts] new_weights.update(experts) return new_weights - - @property - def head_dim(self): - return self.args.d_model // self.args.n_heads - - @property - def n_kv_heads(self): - return self.args.attn_config["kv_n_heads"] diff --git a/llms/mlx_lm/models/deepseek.py b/llms/mlx_lm/models/deepseek.py index dcfa331c..db5199e2 100644 --- a/llms/mlx_lm/models/deepseek.py +++ b/llms/mlx_lm/models/deepseek.py @@ -1,10 +1,10 @@ from dataclasses import dataclass -from typing import Dict, Optional +from typing import Any, Dict, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, _ = x.shape @@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -210,7 +210,7 @@ class DeepseekModel(nn.Module): def __call__( self, x: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.embed_tokens(x) mask = create_attention_mask(h, cache) @@ -235,7 +235,7 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ): out = self.model(inputs, cache) return self.lm_head(out) @@ -256,11 +256,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 602a9710..17d061a8 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -2,12 +2,12 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs): max_position_embeddings: int = 2048 rms_norm_eps: float = 1e-6 rope_theta: float = 10000.0 - rope_scaling: Optional[Dict] = None + rope_scaling: Dict = None attention_bias: bool = False @@ -172,12 +172,11 @@ class DeepseekV2Attention(nn.Module): bias=config.attention_bias, ) - if self.config.rope_scaling is not None: - mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) - scaling_factor = self.config.rope_scaling["factor"] - if mscale_all_dim: - mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) - self.scale = self.scale * mscale * mscale + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.scale = self.scale * mscale * mscale rope_kwargs = { key: self.config.rope_scaling[key] @@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module): def __call__( self, x: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: h = self.embed_tokens(x) mask = create_attention_mask(h, cache) @@ -395,7 +394,7 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ): out = self.model(inputs, cache) return self.lm_head(out) @@ -416,14 +415,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.qk_nope_head_dim + self.args.qk_rope_head_dim, - self.args.v_head_dim, - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py index c6150284..61de781e 100644 --- a/llms/mlx_lm/models/gemma.py +++ b/llms/mlx_lm/models/gemma.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -60,7 +60,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -113,7 +113,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -173,11 +173,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py index 1d410a15..ccc327a8 100644 --- a/llms/mlx_lm/models/gemma2.py +++ b/llms/mlx_lm/models/gemma2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -64,7 +64,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) @@ -135,13 +135,11 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: - r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache) + r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + self.post_attention_layernorm(r) - r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype( - mx.float32 - ) + r = self.mlp(self.pre_feedforward_layernorm(h)) out = h + self.post_feedforward_layernorm(r) return out @@ -200,11 +198,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py index 8a770936..97d9a8ff 100644 --- a/llms/mlx_lm/models/gpt2.py +++ b/llms/mlx_lm/models/gpt2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -46,7 +46,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -100,7 +100,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.ln_1(x), mask, cache) h = x + r @@ -196,11 +196,3 @@ class Model(nn.Module): @property def layers(self): return self.model.h - - @property - def head_dim(self): - return self.args.n_embd // self.args.n_head - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py index 652eb9e4..068046ea 100644 --- a/llms/mlx_lm/models/gpt_bigcode.py +++ b/llms/mlx_lm/models/gpt_bigcode.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -57,7 +57,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -114,7 +114,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.ln_1(x), mask, cache) h = x + r @@ -184,11 +184,3 @@ class Model(nn.Module): @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.n_embd // self.args.n_head - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py index c2aaa9ea..9f662491 100644 --- a/llms/mlx_lm/models/gpt_neox.py +++ b/llms/mlx_lm/models/gpt_neox.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -60,7 +60,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -120,7 +120,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: residual = x # NeoX runs attention and feedforward network in parallel. @@ -214,11 +214,3 @@ class Model(nn.Module): @property def layers(self): return self.model.h - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py index bcc0cf0c..5264cb57 100644 --- a/llms/mlx_lm/models/internlm2.py +++ b/llms/mlx_lm/models/internlm2.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -116,7 +116,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -171,7 +171,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attention(self.attention_norm(x), mask, cache) h = x + r @@ -236,11 +236,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py index c4a947a5..7da6b333 100644 --- a/llms/mlx_lm/models/llama.py +++ b/llms/mlx_lm/models/llama.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -171,7 +171,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -233,7 +233,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -303,13 +303,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index 26408426..d2740dc1 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -7,6 +7,7 @@ import mlx.core as mx import mlx.nn as nn from .base import BaseModelArgs +from .cache import MambaCache @dataclass @@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs): self.time_step_rank = math.ceil(self.hidden_size / 16) -class MambaCache: - def __init__(self): - self.cache = [None, None] - - def __setitem__(self, idx, value): - self.cache[idx] = value - - def __getitem__(self, idx): - return self.cache[idx] - - @property - def state(self): - return self.cache - - class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() @@ -223,7 +209,7 @@ class Model(nn.Module): weights[k] = v.moveaxis(2, 1) return weights - def make_cache(self, batch_size: int = 1): + def make_cache(self): return [MambaCache() for _ in range(len(self.layers))] @property diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index df0670be..4ac3c3b4 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -85,7 +85,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ): B, L, _ = x.shape @@ -135,7 +135,7 @@ class DecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers)) @@ -205,11 +205,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py index 2db57752..20944fe3 100644 --- a/llms/mlx_lm/models/mixtral.py +++ b/llms/mlx_lm/models/mixtral.py @@ -2,7 +2,7 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -66,7 +66,7 @@ class MixtralAttention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -215,11 +215,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/nemotron.py b/llms/mlx_lm/models/nemotron.py index ef55d1d7..3ea06e27 100644 --- a/llms/mlx_lm/models/nemotron.py +++ b/llms/mlx_lm/models/nemotron.py @@ -2,12 +2,12 @@ from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -94,7 +94,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, _ = x.shape @@ -151,7 +151,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -215,13 +215,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return ( - self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads - ) - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py index 59849c96..3627df06 100644 --- a/llms/mlx_lm/models/olmo.py +++ b/llms/mlx_lm/models/olmo.py @@ -1,8 +1,8 @@ # Copyright © 2023-2024 Apple Inc. +import sys from dataclasses import dataclass -from sys import exit -from typing import Optional, Tuple +from typing import Any, Optional, Tuple import mlx.core as mx import mlx.nn as nn @@ -13,7 +13,7 @@ try: import hf_olmo except ImportError: print("To run olmo install ai2-olmo: pip install ai2-olmo") - exit(1) + sys.exit(1) @dataclass @@ -68,7 +68,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -98,7 +98,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attend(self.att_norm(x), mask, cache) h = x + r @@ -174,11 +174,3 @@ class Model(nn.Module): @property def layers(self): return self.model.transformer.blocks - - @property - def head_dim(self): - return self.args.d_model // self.args.n_heads - - @property - def n_kv_heads(self): - return self.args.n_heads diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py index 19d3c027..090e21c6 100644 --- a/llms/mlx_lm/models/openelm.py +++ b/llms/mlx_lm/models/openelm.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn @@ -80,7 +80,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -152,7 +152,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.attn(self.attn_norm(x), mask, cache) h = x + r @@ -218,11 +218,3 @@ class Model(nn.Module): @property def layers(self): return self.transformer.layers - - @property - def head_dim(self): - return self.args.head_dim - - @property - def n_kv_heads(self): - return self.args.num_kv_heads diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py index fd3fd709..56b383b2 100644 --- a/llms/mlx_lm/models/phi.py +++ b/llms/mlx_lm/models/phi.py @@ -162,19 +162,11 @@ class Model(nn.Module): def __call__( self, x: mx.array, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: y = self.model(x, cache) return self.lm_head(y) @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py index 112ade7d..9ef76f04 100644 --- a/llms/mlx_lm/models/phi3.py +++ b/llms/mlx_lm/models/phi3.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .su_rope import SuScaledRotaryEmbedding @@ -84,7 +84,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -143,7 +143,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -202,11 +202,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py index 665dbc73..6b0759b4 100644 --- a/llms/mlx_lm/models/phi3small.py +++ b/llms/mlx_lm/models/phi3small.py @@ -3,12 +3,12 @@ import math from dataclasses import dataclass from functools import partial -from typing import Dict, Optional, Tuple, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs): num_attention_heads: int layer_norm_epsilon: float vocab_size: int - num_key_value_heads: Optional[int] = None + num_key_value_heads: int mup_attn_multiplier: float = 1.0 mup_use_scaling: bool = True mup_embedding_multiplier: float = 10.0 mup_width_multiplier: float = 8.0 rope_embedding_base: float = 1000000 rope_position_scale: float = 1.0 - blocksparse_block_size: Tuple[int] = (64,) + blocksparse_block_size: int = 64 blocksparse_num_local_blocks: int = 16 blocksparse_vert_stride: int = 8 @@ -61,7 +61,6 @@ class Attention(nn.Module): dim = args.hidden_size self.n_heads = n_heads = args.num_attention_heads - assert args.num_key_value_heads is not None self.n_kv_heads = n_kv_heads = args.num_key_value_heads self.n_q_per_kv = n_heads // n_kv_heads @@ -161,7 +160,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -230,7 +229,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -304,16 +303,8 @@ class Model(nn.Module): def layers(self): return self.model.layers - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - def sanitize(self, weights): # Remove unused precomputed rotary freqs return { k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k } - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phimoe.py b/llms/mlx_lm/models/phimoe.py index db6bd4b5..ca20a388 100644 --- a/llms/mlx_lm/models/phimoe.py +++ b/llms/mlx_lm/models/phimoe.py @@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module): class Model(nn.Module): def __init__(self, args: ModelArgs): super().__init__() + self.model_type = args.model_type self.args = args self.model = PhiMoEModel(args) self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True) @@ -208,11 +209,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py index bb67615d..865d0d8e 100644 --- a/llms/mlx_lm/models/phixtral.py +++ b/llms/mlx_lm/models/phixtral.py @@ -168,8 +168,8 @@ class Model(nn.Module): self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: mask = create_attention_mask(x, cache) y = self.transformer(x, mask, cache) @@ -193,11 +193,3 @@ class Model(nn.Module): @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.model_dim // self.args.num_heads - - @property - def n_kv_heads(self): - return self.args.num_heads diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py index 5d2b7586..090922ae 100644 --- a/llms/mlx_lm/models/plamo.py +++ b/llms/mlx_lm/models/plamo.py @@ -1,7 +1,7 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn @@ -62,8 +62,8 @@ class Attention(nn.Module): self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]: + cache: Optional[Any] = None, + ) -> mx.array: bsz, q_len, _ = hidden_states.shape queries = self.q_proj(hidden_states) @@ -127,8 +127,8 @@ class PlamoDecoderLayer(nn.Module): self, hidden_states: mx.array, attention_mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> Tuple[Any, ...]: + cache: Optional[Any] = None, + ): # from LlamaDecoder residual = hidden_states @@ -169,8 +169,8 @@ class PlamoModel(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None, - ) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]: + cache: Optional[Any] = None, + ) -> mx.array: h = self.embed_tokens(inputs) mask = create_attention_mask(h, cache) @@ -197,19 +197,11 @@ class Model(nn.Module): def __call__( self, inputs: mx.array, - cache: Optional[List[Tuple[mx.array, mx.array]]] = None, - ) -> Tuple[mx.array, mx.array]: + cache: Optional[Any] = None, + ) -> mx.array: out = self.model(inputs, cache) return self.lm_head(out) @property def layers(self): return self.model.layers.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_attention_heads // self.args.n_shared_head diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py index 6d2c7bbf..2b69d5ec 100644 --- a/llms/mlx_lm/models/qwen.py +++ b/llms/mlx_lm/models/qwen.py @@ -1,7 +1,6 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Tuple import mlx.core as mx import mlx.nn as nn @@ -149,19 +148,11 @@ class Model(nn.Module): self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: y = self.transformer(x, mask, cache) return self.lm_head(y) @property def layers(self): return self.transformer.h - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_attention_heads diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py index b3ce02a3..4e7858de 100644 --- a/llms/mlx_lm/models/qwen2.py +++ b/llms/mlx_lm/models/qwen2.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -70,7 +70,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -124,7 +124,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -196,11 +196,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py index ff7831f3..d199116f 100644 --- a/llms/mlx_lm/models/qwen2_moe.py +++ b/llms/mlx_lm/models/qwen2_moe.py @@ -2,12 +2,12 @@ import math from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union +from typing import Any, Dict, Optional, Union import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask from .switch_layers import SwitchGLU @@ -70,7 +70,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -236,11 +236,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py index 34750ace..b861b286 100644 --- a/llms/mlx_lm/models/recurrent_gemma.py +++ b/llms/mlx_lm/models/recurrent_gemma.py @@ -7,13 +7,13 @@ from typing import List, Literal, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs +from .base import BaseModelArgs, create_attention_mask +from .cache import MambaCache, RotatingKVCache @dataclass class ModelArgs(BaseModelArgs): model_type: str - hidden_size: int attention_bias: bool conv1d_width: int hidden_size: int @@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs): self.block_types = self._block_types -def create_window_causal_mask(N: int, window_size: int): - inds = mx.arange(N) - linds = inds[:, None] - rinds = inds[None] - mask = (linds < rinds) | (linds > rinds + window_size) - return mask * -1e9 - - -class RecurrentCache: - - def __init__(self): - self._cache = (None, None) - - def __getitem__(self, idx): - return self._cache[idx] - - def update(self, conv_state, recurrent_state): - self._cache = (conv_state, recurrent_state) - - def state(self): - return self._cache - - -class WindowKVCache: - - def __init__(self, window_size): - self.keys = None - self.values = None - self.offset = 0 - self.window_size = window_size - - def update_and_fetch(self, keys, values): - # TODO consider using rotating buffer here - # especially for very long generations - def _update(x, v): - t = x.shape[2] - self.window_size - if t > 0: - x = x[..., t:, :] - return mx.concatenate([x, v], axis=2) - - self.offset += keys.shape[2] - if self.keys is None: - self.keys = keys - self.values = values - else: - self.keys = _update(self.keys, keys) - self.values = _update(self.values, values) - return self.keys, self.values - - def state(self): - return self.keys, self.values - - class RMSNorm(nn.Module): def __init__(self, dims: int, eps: float = 1e-5): super().__init__() @@ -136,31 +83,22 @@ class Conv1d(nn.Module): kernel_size: int, ): super().__init__() - self.weight = mx.zeros((kernel_size, channels)) + self.weight = mx.zeros((channels, kernel_size, 1)) self.bias = mx.zeros((channels,)) def __call__(self, x, cache=None): - w = self.weight.T[..., None] - kw, groups = self.weight.shape - if cache is not None: - l = [] - # Pad the cache if needed - if cache.shape[1] < kw - 1: - l.append( - mx.zeros( - (x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype - ) - ) - l.extend([cache, x]) - x = mx.concatenate(l, axis=1) - y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True) - else: - y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups) + B, L, C = x.shape + groups, K, _ = self.weight.shape - # The cache is always kw - 1 - cache = x[:, max(x.shape[1] - kw + 1, 0) :, :] + if cache is not None: + x = mx.concatenate([cache, x], axis=1) + else: + x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) + + y = mx.conv_general(x, self.weight, groups=groups) y = y + self.bias - return y, cache + + return y, x[:, -K + 1 :, :] class RGLRU(nn.Module): @@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module): # x branch. x = self.linear_x(x) if cache is None: - conv_state, recurrent_state = (None, None) - else: - conv_state, recurrent_state = cache[0], cache[1] - x, conv_state = self.conv_1d( - x=x, - cache=conv_state, - ) - x, recurrent_state = self.rg_lru( - x=x, - cache=recurrent_state, - ) - if cache is not None: - cache.update(conv_state, recurrent_state) + cache = [None, None] + x, cache[0] = self.conv_1d(x=x, cache=cache[0]) + x, cache[1] = self.rg_lru(x=x, cache=cache[1]) x = x * y x = self.linear_out(x) @@ -467,12 +395,10 @@ class Griffin(nn.Module): if self.scale_by_sqrt_dim: x = x * math.sqrt(x.shape[-1]) - mask = None - if x.shape[1] > 1: - mask = create_window_causal_mask( - x.shape[1], self.config.attention_window_size - ) - mask = mask.astype(x.dtype) + if cache is None: + cache = [None] * len(self.layers) + + mask = create_attention_mask(x, cache) for i, block in enumerate(self.layers): x = block(x, mask=mask, cache=cache[i]) @@ -485,6 +411,7 @@ class Model(nn.Module): def __init__(self, config): self.args = config self.model = Griffin(config) + self.model_type = config.model_type self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) def __call__(self, tokens: mx.array, cache=None) -> mx.array: @@ -508,10 +435,9 @@ class Model(nn.Module): return self.model.layers def sanitize(self, weights): - # Remove unused precomputed rotary freqs for k, v in weights.items(): if "conv_1d.weight" in k and v.ndim == 3: - weights[k] = v.squeeze(1).T + weights[k] = v.moveaxis(2, 1) if "lm_head.weight" not in weights: self.pop("lm_head") return weights @@ -520,7 +446,7 @@ class Model(nn.Module): cache = [] for layer in self.layers: if layer.temporal_block_type == "recurrent": - cache.append(RecurrentCache()) + cache.append(MambaCache()) else: - cache.append(WindowKVCache(self.args.attention_window_size)) + cache.append(RotatingKVCache(max_size=self.args.attention_window_size)) return cache diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py index b340de28..11202b02 100644 --- a/llms/mlx_lm/models/stablelm.py +++ b/llms/mlx_lm/models/stablelm.py @@ -2,7 +2,6 @@ import math from dataclasses import dataclass -from typing import Tuple import mlx.core as mx import mlx.nn as nn @@ -198,8 +197,8 @@ class Model(nn.Module): self, x: mx.array, mask: mx.array = None, - cache: mx.array = None, - ) -> Tuple[mx.array, mx.array]: + cache=None, + ) -> mx.array: mask = create_attention_mask(x, cache) y = self.model(x, mask, cache) return self.lm_head(y) @@ -207,11 +206,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py index 9cec0e39..ce0a2ec5 100644 --- a/llms/mlx_lm/models/starcoder2.py +++ b/llms/mlx_lm/models/starcoder2.py @@ -1,12 +1,12 @@ # Copyright © 2023-2024 Apple Inc. from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Any, Optional import mlx.core as mx import mlx.nn as nn -from .base import BaseModelArgs, KVCache, create_attention_mask +from .base import BaseModelArgs, create_attention_mask @dataclass @@ -45,7 +45,7 @@ class Attention(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: B, L, D = x.shape @@ -100,7 +100,7 @@ class TransformerBlock(nn.Module): self, x: mx.array, mask: Optional[mx.array] = None, - cache: Optional[KVCache] = None, + cache: Optional[Any] = None, ) -> mx.array: r = self.self_attn(self.input_layernorm(x), mask, cache) h = x + r @@ -164,11 +164,3 @@ class Model(nn.Module): @property def layers(self): return self.model.layers - - @property - def head_dim(self): - return self.args.hidden_size // self.args.num_attention_heads - - @property - def n_kv_heads(self): - return self.args.num_key_value_heads diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 54a96457..7017b9d3 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -18,7 +18,7 @@ from mlx.utils import tree_flatten from transformers import PreTrainedTokenizer # Local imports -from .models.base import KVCache, RotatingKVCache +from .models import base, cache from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling from .tokenizer_utils import TokenizerWrapper, load_tokenizer from .tuner.utils import dequantize as dequantize_model @@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float) return logits -def make_kv_caches( - model: nn.Module, max_kv_size: Optional[int] = None -) -> List[Union[KVCache, RotatingKVCache]]: - if hasattr(model, "make_cache"): - return model.make_cache() - - kv_heads = ( - [model.n_kv_heads] * len(model.layers) - if isinstance(model.n_kv_heads, int) - else model.n_kv_heads - ) - if max_kv_size is not None: - return [ - RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4) - for n in kv_heads - ] - else: - return [KVCache(model.head_dim, n) for n in kv_heads] - - def generate_step( prompt: mx.array, model: nn.Module, @@ -155,7 +135,7 @@ def generate_step( min_tokens_to_keep: int = 1, prefill_step_size: int = 512, max_kv_size: Optional[int] = None, - cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None, + prompt_cache: Optional[Any] = None, logit_bias: Optional[Dict[int, float]] = None, logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, ) -> Generator[Tuple[mx.array, mx.array], None, None]: @@ -180,6 +160,8 @@ def generate_step( prefill_step_size (int): Step size for processing the prompt. max_kv_size (int, optional): Maximum size of the key-value cache. Old entries (except the first 4 tokens) will be overwritten. + prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if + provided, the cache will be updated in place. logit_bias (dictionary, optional): Additive logit bias. logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional): A list of functions that take tokens and logits and return the processed @@ -237,20 +219,13 @@ def generate_step( tokens = None # Create the KV cache for generation - cache = make_kv_caches(model, max_kv_size) - - if cache_history is not None: - if len(cache_history) != len(cache): - raise ValueError("Wrong number of layers in the cache history") - - # Set the history in the cache objects and evaluate them to prepare for - # generation. - for c, h in zip(cache, cache_history): - c.update_and_fetch(h[0], h[1]) - mx.eval([c.state for c in cache]) + if prompt_cache is None: + prompt_cache = cache.make_prompt_cache(model, max_kv_size) + elif len(prompt_cache) != len(model.layers): + raise ValueError("Wrong number of layers in the prompt cache.") def _step(y): - logits = model(y[None], cache=cache) + logits = model(y[None], cache=prompt_cache) logits = logits[:, -1, :] if logits_processor: @@ -265,7 +240,7 @@ def generate_step( while y.size > prefill_step_size: model(y[:prefill_step_size][None], cache=cache) - mx.eval([c.state for c in cache]) + mx.eval([c.state[0] for c in cache]) y = y[prefill_step_size:] y, logprobs = _step(y) @@ -305,9 +280,9 @@ def stream_generate( detokenizer = tokenizer.detokenizer detokenizer.reset() - for (token, _), n in zip( - generate_step(prompt_tokens, model, **kwargs), + for n, (token, _) in zip( range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), ): if token == tokenizer.eos_token_id: break @@ -357,9 +332,9 @@ def generate( tic = time.perf_counter() detokenizer.reset() - for (token, logprobs), n in zip( - generate_step(prompt_tokens, model, **kwargs), + for n, (token, logprobs) in zip( range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), ): if n == 0: prompt_time = time.perf_counter() - tic diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index cb676a47..a839f797 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -1,5 +1,4 @@ # Copyright © 2024 Apple Inc. - import unittest import mlx.core as mx @@ -11,7 +10,7 @@ from mlx_lm.utils import make_kv_caches class TestModels(unittest.TestCase): def test_kv_cache(self): - cache = KVCache(32, 4) + cache = KVCache() k = mx.ones((1, 4, 1, 32), mx.float16) v = mx.ones((1, 4, 1, 32), mx.float16) @@ -32,7 +31,7 @@ class TestModels(unittest.TestCase): def test_rotating_kv_cache(self): b, h, d = 1, 2, 32 - cache = RotatingKVCache(d, h, max_size=8, step=4) + cache = RotatingKVCache(max_size=8, step=4) k = mx.random.uniform(shape=(b, h, 2, d)) v = mx.random.uniform(shape=(b, h, 2, d)) @@ -65,7 +64,7 @@ class TestModels(unittest.TestCase): idx %= 8 # Try with nonzero keep - cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2) + cache = RotatingKVCache(max_size=8, step=4, keep=2) # Check a large update k = mx.random.uniform(shape=(b, h, 20, d)) @@ -93,7 +92,7 @@ class TestModels(unittest.TestCase): # alternating prompt/prefill with generation d = 4 h = 2 - cache = RotatingKVCache(d, h, max_size=18, step=4) + cache = RotatingKVCache(max_size=18, step=4) x = mx.random.uniform(shape=(1, h, 8, d)) k, v = cache.update_and_fetch(x, x) @@ -589,6 +588,179 @@ class TestModels(unittest.TestCase): model, args.model_type, args.vocab_size, args.num_hidden_layers ) + def test_deepseek(self): + from mlx_lm.models import deepseek + + args = deepseek.ModelArgs( + model_type="deepseek", + vocab_size=1024, + hidden_size=128, + intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=4, + ) + model = deepseek.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_deepseek_v2(self): + from mlx_lm.models import deepseek_v2 + + args = deepseek_v2.ModelArgs( + model_type="deepseek_v2", + vocab_size=1024, + hidden_size=128, + intermediate_size=256, + moe_intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + kv_lora_rank=4, + q_lora_rank=4, + qk_rope_head_dim=32, + v_head_dim=16, + qk_nope_head_dim=32, + rope_scaling={ + "beta_fast": 32, + "beta_slow": 1, + "factor": 40, + "mscale": 1.0, + "mscale_all_dim": 1.0, + "original_max_position_embeddings": 4096, + "type": "yarn", + }, + ) + model = deepseek_v2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_gemma2(self): + from mlx_lm.models import gemma2 + + args = gemma2.ModelArgs( + model_type="gemma2", + hidden_size=128, + num_hidden_layers=4, + intermediate_size=256, + num_attention_heads=2, + head_dim=32, + rms_norm_eps=1e-4, + vocab_size=1024, + num_key_value_heads=2, + ) + model = gemma2.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_gpt_bigcode(self): + from mlx_lm.models import gpt_bigcode + + args = gpt_bigcode.ModelArgs( + model_type="gpt_bigcode", + n_embd=128, + n_layer=128, + n_inner=256, + n_head=4, + n_positions=1000, + layer_norm_epsilon=1e-5, + vocab_size=1024, + ) + model = gpt_bigcode.Model(args) + self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer) + + def test_nemotron(self): + from mlx_lm.models import nemotron + + args = nemotron.ModelArgs( + model_type="nemotron", + hidden_size=128, + hidden_act="gelu", + num_hidden_layers=4, + intermediate_size=256, + num_attention_heads=4, + norm_eps=1e-5, + vocab_size=1024, + num_key_value_heads=2, + ) + model = nemotron.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_phi3small(self): + from mlx_lm.models import phi3small + + args = phi3small.ModelArgs( + model_type="phi3small", + hidden_size=128, + dense_attention_every_n_layers=2, + ff_intermediate_size=256, + gegelu_limit=1.0, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=2, + layer_norm_epsilon=1e-4, + vocab_size=1000, + ) + model = phi3small.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_phimoe(self): + from mlx_lm.models import phimoe + + args = phimoe.ModelArgs( + model_type="phimoe", + vocab_size=320, + hidden_size=128, + intermediate_size=256, + num_hidden_layers=4, + num_attention_heads=4, + num_key_value_heads=4, + rope_scaling={ + "long_factor": [1.0] * 16, + "long_mscale": 1.243163121016122, + "original_max_position_embeddings": 4096, + "short_factor": [1.0] * 16, + "short_mscale": 1.243163121016122, + "type": "longrope", + }, + ) + model = phimoe.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + + def test_recurrent_gemma(self): + from mlx_lm.models import recurrent_gemma + + args = recurrent_gemma.ModelArgs( + model_type="recurrent_gemma", + hidden_size=128, + attention_bias=False, + conv1d_width=3, + intermediate_size=256, + logits_soft_cap=1.0, + num_attention_heads=4, + num_hidden_layers=4, + num_key_value_heads=2, + rms_norm_eps=1e-4, + rope_theta=1000, + attention_window_size=1024, + vocab_size=1000, + block_types=["recurrent", "recurrent", "attention"], + ) + model = recurrent_gemma.Model(args) + self.model_test_runner( + model, args.model_type, args.vocab_size, args.num_hidden_layers + ) + if __name__ == "__main__": unittest.main() diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py new file mode 100644 index 00000000..14b5e961 --- /dev/null +++ b/llms/tests/test_prompt_cache.py @@ -0,0 +1,143 @@ +# Copyright © 2024 Apple Inc. + +import os +import tempfile +import unittest + +import mlx.core as mx +from mlx_lm.models.cache import ( + KVCache, + MambaCache, + RotatingKVCache, + load_prompt_cache, + make_prompt_cache, + save_prompt_cache, +) +from mlx_lm.utils import generate_step, load + +HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" + + +class TestPromptCache(unittest.TestCase): + + @classmethod + def setUpClass(cls): + cls.test_dir_fid = tempfile.TemporaryDirectory() + cls.test_dir = cls.test_dir_fid.name + + @classmethod + def tearDownClass(cls): + cls.test_dir_fid.cleanup() + + def test_save_load(self): + cache = [KVCache() for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0])) + self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1])) + + # Test with metadata + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + metadata = {"a": "b", "c": "d"} + save_prompt_cache(cache_file, cache, metadata) + _, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True) + self.assertEqual(metadata, loaded_metadata) + + def test_save_load_rotating_cache(self): + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + + # Test with rotating cache + cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)] + for c in cache: + x = mx.random.uniform(shape=(1, 8, 10, 4)) + c.update_and_fetch(x, x) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + self.assertTrue(len(cache), len(loaded_cache)) + for c, lc in zip(cache, loaded_cache): + self.assertEqual(c.offset, lc.offset) + self.assertEqual(c.keep, lc.keep) + self.assertEqual(c.max_size, lc.max_size) + self.assertEqual(c.step, lc.step) + self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0])) + self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1])) + + # Do a couple single token updates to get a rotation + for _ in range(2): + for c in cache: + x = mx.random.uniform(shape=(1, 8, 1, 4)) + c.update_and_fetch(x, x) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + + for c, lc in zip(cache, loaded_cache): + x = mx.random.uniform(shape=(1, 8, 1, 4)) + k, v = c.update_and_fetch(x, x) + lk, lv = lc.update_and_fetch(x, x) + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(k, lk)) + self.assertTrue(mx.array_equal(v, lv)) + + def test_save_load_mixed_cache(self): + cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors") + + cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()] + for c in cache: + if isinstance(c, MambaCache): + c[0] = mx.random.uniform(shape=(4, 4, 4)) + c[1] = mx.random.uniform(shape=(4, 4, 4)) + else: + x = mx.random.uniform(shape=(4, 4, 7, 4)) + y = mx.random.uniform(shape=(4, 4, 7, 4)) + c.update_and_fetch(x, y) + + save_prompt_cache(cache_file, cache) + loaded_cache = load_prompt_cache(cache_file) + for c, lc in zip(cache, loaded_cache): + if isinstance(c, MambaCache): + self.assertTrue(mx.array_equal(c[0], lc[0])) + self.assertTrue(mx.array_equal(c[1], lc[1])) + else: + x = mx.random.uniform(shape=(4, 4, 1, 4)) + y = mx.random.uniform(shape=(4, 4, 1, 4)) + k, v = c.update_and_fetch(x, y) + lk, lv = lc.update_and_fetch(x, y) + self.assertEqual(c.offset, lc.offset) + self.assertTrue(mx.array_equal(k, lk)) + self.assertTrue(mx.array_equal(v, lv)) + + def test_cache_with_generate(self): + model, tokenizer = load(HF_MODEL_PATH) + prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0] + results = zip(range(4), generate_step(prompt, model)) + toks, all_logits = zip(*(r[1] for r in results)) + + prompt_cache = make_prompt_cache(model) + i = 0 + for _, (tok, logits) in zip( + range(2), generate_step(prompt, model, prompt_cache=prompt_cache) + ): + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + i += 1 + + for _, (tok, logits) in zip( + range(1), + generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache), + ): + i += 1 + self.assertEqual(tok, toks[i]) + self.assertTrue(mx.allclose(logits, all_logits[i])) + + +if __name__ == "__main__": + unittest.main()