Merge branch 'adding-support-for-mamba2' of https://github.com/Goekdeniz-Guelmez/mlx-examples into adding-support-for-mamba2

This commit is contained in:
Goekdeniz-Guelmez
2024-10-16 21:09:42 +02:00
90 changed files with 4951 additions and 998 deletions

View File

@@ -20,6 +20,31 @@ The `mlx-lm` package also has:
- [Merging models](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/MERGE.md)
- [HTTP model serving](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/SERVER.md)
### Quick Start
To generate text with an LLM use:
```bash
mlx_lm.generate --prompt "Hi!"
```
To chat with an LLM use:
```bash
mlx_lm.chat
```
This will give you a chat REPL that you can use to interact with the LLM. The
chat context is preserved during the lifetime of the REPL.
Commands in `mlx-lm` typically take command line options which let you specify
the model, sampling parameters, and more. Use `-h` to see a list of available
options for a command, e.g.:
```bash
mlx_lm.generate -h
```
### Python API
You can use `mlx-lm` as a module:
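A minimal sketch of the module usage (the helper names `load` and `generate` match the calls used in the examples added later in this commit; the prompt string here is just an illustration):
```python
# Minimal sketch: load a model and generate text from Python.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Llama-3.2-3B-Instruct-4bit")

# `verbose=True` prints the generated text along with timing information.
response = generate(model, tokenizer, prompt="Write a haiku about the ocean.", verbose=True)
```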
@@ -138,7 +163,7 @@ mlx_lm.convert \
### Long Prompts and Generations
MLX LM has some tools to scale efficiently to long prompts and generations:
`mlx-lm` has some tools to scale efficiently to long prompts and generations:
- A rotating fixed-size key-value cache.
- Prompt caching
@@ -155,14 +180,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
cat prompt.txt | mlx_lm.cache_prompt \
--model mistralai/Mistral-7B-Instruct-v0.3 \
--prompt - \
--kv-cache-file mistral_prompt.safetensors
--prompt-cache-file mistral_prompt.safetensors
```
Then use the cached prompt with `mlx_lm.generate`:
```
mlx_lm.generate \
--kv-cache-file mistral_prompt.safetensors \
--prompt-cache-file mistral_prompt.safetensors \
--prompt "\nSummarize the above text."
```
@@ -170,9 +195,15 @@ The cached prompt is treated as a prefix to the supplied prompt. Also notice
when using a cached prompt, the model to use is read from the cache and need
not be supplied explicitly.
Prompt caching can also be used in the Python API in order to avoid
recomputing the prompt. This is useful in multi-turn dialogues or across
requests that use the same context. See the
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
for more usage details.
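A condensed sketch of that pattern, mirroring the `examples/chat.py` file added in this commit (`make_prompt_cache` lives in `mlx_lm.models.cache`):
```python
from mlx_lm import generate, load
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Reuse the same cache across turns so earlier context is not recomputed.
prompt_cache = make_prompt_cache(model)

for question in ["Hi, my name is <Name>.", "What's my name?"]:
    messages = [{"role": "user", "content": question}]
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    generate(model, tokenizer, prompt=prompt, prompt_cache=prompt_cache, verbose=True)
```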
### Supported Models
MLX LM supports thousands of Hugging Face format LLMs. If the model you want to
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
run is not supported, file an
[issue](https://github.com/ml-explore/mlx-examples/issues/new) or better yet,
submit a pull request.

View File

@@ -50,7 +50,7 @@ curl localhost:8080/v1/chat/completions \
- `role_mapping`: (Optional) A dictionary to customize the role prefixes in
the generated prompt. If not provided, the default mappings are used.
- `stop`: (Optional) An array of strings or a single string. Thesse are
- `stop`: (Optional) An array of strings or a single string. These are
sequences of tokens on which the generation should stop.
- `max_tokens`: (Optional) An integer specifying the maximum number of tokens
@@ -84,7 +84,37 @@ curl localhost:8080/v1/chat/completions \
started in.
- `adapters`: (Optional) A string path to low-rank adapters. The path must be
rlative to the directory the server was started in.
relative to the directory the server was started in.
### Response Fields
- `id`: A unique identifier for the chat.
- `system_fingerprint`: A unique identifier for the system.
- `object`: Any of "chat.completions", "chat.completions.chunk" (for
streaming), or "text.completion".
- `model`: The model repo or path (e.g. `"mlx-community/Llama-3.2-3B-Instruct-4bit"`).
- `created`: A time-stamp for when the request was processed.
- `choices`: A list of outputs. Each output is a dictionary containing the fields:
- `index`: The index in the list.
- `logprobs`: A dictionary containing the fields:
- `token_logprobs`: A list of the log probabilities for the generated
tokens.
- `tokens`: A list of the generated token ids.
- `top_logprobs`: A list of lists. Each inner list contains the top
`logprobs` tokens (if requested) with their corresponding probabilities.
- `finish_reason`: The reason the completion ended. This can be either
`"stop"` or `"length"`.
- `message`: The text response from the model.
- `usage`: A dictionary containing the fields:
- `prompt_tokens`: The number of prompt tokens processed.
- `completion_tokens`: The number of tokens generated.
- `total_tokens`: The total number of tokens, i.e. the sum of the above two fields.
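For illustration, a small client that reads these fields. This is a sketch only: it assumes the server is already running on `localhost:8080` and that the request body follows the OpenAI-style format from the `curl` example above.
```python
import json
from urllib import request

payload = {
    "messages": [{"role": "user", "content": "Say hello"}],
    "max_tokens": 32,
}
req = request.Request(
    "http://localhost:8080/v1/chat/completions",
    data=json.dumps(payload).encode(),
    headers={"Content-Type": "application/json"},
)
with request.urlopen(req) as resp:
    result = json.load(resp)

# Pull out the documented response fields.
choice = result["choices"][0]
print(choice["message"], choice["finish_reason"])
print(result["usage"]["prompt_tokens"], result["usage"]["completion_tokens"])
```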
### List Models
@@ -97,5 +127,5 @@ curl localhost:8080/v1/models -H "Content-Type: application/json"
This will return a list of locally available models where each model in the
list contains the following fields:
- `"id"`: The Hugging Face repo id.
- `"created"`: A timestamp representing the model creation time.
- `id`: The Hugging Face repo id.
- `created`: A time-stamp representing the model creation time.

View File

@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.18.2"
__version__ = "0.19.1"

View File

@@ -7,13 +7,14 @@ import time
import mlx.core as mx
from .utils import load, make_kv_caches
from .models.cache import make_prompt_cache, save_prompt_cache
from .utils import load
def setup_arg_parser():
"""Set up and return the argument parser."""
parser = argparse.ArgumentParser(
description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
description="Cache the state of a prompt to be reused with mlx_lm.generate"
)
parser.add_argument(
"--model",
@@ -60,7 +61,9 @@ def setup_arg_parser():
help="Set the maximum key-value cache size",
)
parser.add_argument(
"--kv-cache-file", help="The file to save the KV caches in", required=True
"--prompt-cache-file",
help="The file to save the prompt cache in",
required=True,
)
parser.add_argument(
"--prompt",
@@ -115,7 +118,7 @@ def main():
else:
prompt = args.prompt
cache = make_kv_caches(model, args.max_kv_size)
cache = make_prompt_cache(model, args.max_kv_size)
y = mx.array(tokenizer.encode(prompt))
# Process the prompt
@@ -137,16 +140,12 @@ def main():
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
print("Saving...")
cache_dict = {}
for i, c in enumerate(cache):
cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
metadata = {}
metadata["model"] = args.model
metadata["chat_template"] = tokenizer.chat_template
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
metadata["max_kv_size"] = str(args.max_kv_size)
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
save_prompt_cache(args.prompt_cache_file, cache, metadata)
if __name__ == "__main__":

llms/mlx_lm/chat.py (new file, 82 lines)
View File

@@ -0,0 +1,82 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import json

import mlx.core as mx

from .models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
from .utils import load, stream_generate

DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"


def setup_arg_parser():
    """Set up and return the argument parser."""
    parser = argparse.ArgumentParser(description="Chat with an LLM")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
        default=DEFAULT_MODEL,
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Optional path for the trained adapter weights and config.",
    )
    parser.add_argument(
        "--temp", type=float, default=DEFAULT_TEMP, help="Sampling temperature"
    )
    parser.add_argument(
        "--top-p", type=float, default=DEFAULT_TOP_P, help="Sampling top-p"
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="PRNG seed")
    parser.add_argument(
        "--max-kv-size",
        type=int,
        help="Set the maximum key-value cache size",
        default=None,
    )
    return parser


def main():
    parser = setup_arg_parser()
    args = parser.parse_args()
    mx.random.seed(args.seed)
    model, tokenizer = load(
        args.model,
        adapter_path=args.adapter_path,
        tokenizer_config={"trust_remote_code": True},
    )
    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
    prompt_cache = make_prompt_cache(model, args.max_kv_size)
    while True:
        query = input(">> ")
        if query == "q":
            break
        messages = [{"role": "user", "content": query}]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        for response in stream_generate(
            model,
            tokenizer,
            prompt,
            temp=args.temp,
            top_p=args.top_p,
            prompt_cache=prompt_cache,
        ):
            print(response, flush=True, end="")
        print()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,53 @@
# Copyright © 2024 Apple Inc.
"""
An example of a multi-turn chat with prompt caching.
"""
from mlx_lm import generate, load
from mlx_lm.models.cache import load_prompt_cache, make_prompt_cache, save_prompt_cache
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
# Make the initial prompt cache for the model
prompt_cache = make_prompt_cache(model)
# User turn
prompt = "Hi my name is <Name>."
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)
# User turn
prompt = "What's my name?"
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Assistant response
response = generate(
model,
tokenizer,
prompt=prompt,
verbose=True,
temp=0.0,
prompt_cache=prompt_cache,
)
# Save the prompt cache to disk to reuse it at a later time
save_prompt_cache("mistral_prompt.safetensors", prompt_cache)
# Load the prompt cache from disk
prompt_cache = load_prompt_cache("mistral_prompt.safetensors")

View File

@@ -1,3 +1,5 @@
# Copyright © 2024 Apple Inc.
from mlx_lm import generate, load
# Specify the checkpoint

View File

@@ -6,13 +6,15 @@ import sys
import mlx.core as mx
from .models.cache import load_prompt_cache
from .utils import generate, load
DEFAULT_PROMPT = "hello"
DEFAULT_MAX_TOKENS = 100
DEFAULT_TEMP = 0.6
DEFAULT_TEMP = 0.0
DEFAULT_TOP_P = 1.0
DEFAULT_SEED = 0
DEFAULT_MODEL = "mlx-community/Llama-3.2-3B-Instruct-4bit"
def str2bool(string):
@@ -25,7 +27,11 @@ def setup_arg_parser():
parser.add_argument(
"--model",
type=str,
help="The path to the local model directory or Hugging Face repo.",
help=(
"The path to the local model directory or Hugging Face repo. "
f"If no model is specified, then {DEFAULT_MODEL} is used."
),
default=None,
)
parser.add_argument(
"--adapter-path",
@@ -96,7 +102,7 @@ def setup_arg_parser():
default=None,
)
parser.add_argument(
"--kv-cache-file",
"--prompt-cache-file",
type=str,
default=None,
help="A file containing saved KV caches to avoid recomputing them",
@@ -131,24 +137,6 @@ def colorprint_by_t0(s, t0):
colorprint(color, s)
def load_kv_cache_from_file(kv_cache_file):
if kv_cache_file is None:
return None, None
kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
cache_per_layer = {}
for k, x in kv_cache.items():
layer, kv_type = k.split("_")
if layer not in cache_per_layer:
cache_per_layer[layer] = {}
cache_per_layer[layer][kv_type] = x
cache_history = [None] * len(cache_per_layer)
for layer, c in cache_per_layer.items():
cache_history[int(layer)] = (c["keys"], c["values"])
return cache_history, metadata
def main():
parser = setup_arg_parser()
args = parser.parse_args()
@@ -158,22 +146,33 @@ def main():
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Load the kv cache and metadata if a kv cache file is provided
cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
# Load the prompt cache and metadata if a cache file is provided
using_cache = args.prompt_cache_file is not None
if using_cache:
prompt_cache, metadata = load_prompt_cache(
args.prompt_cache_file, return_metadata=True
)
# Building tokenizer_config
tokenizer_config = (
{} if cache_history is None else json.loads(metadata["tokenizer_config"])
{} if not using_cache else json.loads(metadata["tokenizer_config"])
)
if args.trust_remote_code:
tokenizer_config["trust_remote_code"] = True
if args.eos_token is not None:
tokenizer_config["eos_token"] = args.eos_token
# If no model path is provided then use the one in the kv cache history
model_path = args.model
if cache_history is not None and model_path is None:
model_path = metadata["model"]
if using_cache:
if model_path is None:
model_path = metadata["model"]
elif model_path != metadata["model"]:
raise ValueError(
f"Providing a different model ({model_path}) than that "
f"used to create the prompt cache ({metadata['model']}) "
"is an error."
)
model_path = model_path or DEFAULT_MODEL
model, tokenizer = load(
model_path,
@@ -184,7 +183,7 @@ def main():
if args.use_default_chat_template:
if tokenizer.chat_template is None:
tokenizer.chat_template = tokenizer.default_chat_template
elif cache_history is not None:
elif using_cache:
tokenizer.chat_template = metadata["chat_template"]
if not args.ignore_chat_template and (
@@ -203,7 +202,7 @@ def main():
# Treat the prompt as a suffix assuming that the prefix is in the
# stored kv cache.
if cache_history is not None:
if using_cache:
test_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": "<query>"}],
tokenize=False,
@@ -217,12 +216,6 @@ def main():
raise ValueError("Cannot use --colorize with --verbose=False")
formatter = colorprint_by_t0 if args.colorize else None
# Determine the max kv size from the kv cache or passed arguments
max_kv_size = args.max_kv_size
if cache_history is not None:
max_kv_size = metadata["max_kv_size"]
max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
response = generate(
model,
tokenizer,
@@ -232,8 +225,8 @@ def main():
formatter=formatter,
temp=args.temp,
top_p=args.top_p,
max_kv_size=max_kv_size,
cache_history=cache_history,
max_kv_size=args.max_kv_size,
prompt_cache=prompt_cache if using_cache else None,
)
if not args.verbose:
print(response)

View File

@@ -2,145 +2,9 @@
import inspect
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
class KVCache:
def __init__(self, head_dim, n_kv_heads):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B = keys.shape[0]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
return self.keys, self.values
class RotatingKVCache:
def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def update_and_fetch(self, keys, values):
prev = self.offset
B, _, S = keys.shape[:3]
# Prefill mode
if S > 1:
if self.keys is None:
self.keys = keys
self.values = values
else:
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self.keys.shape[2] - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += S
self._idx = self.keys.shape[2]
return self.keys, self.values
# Generation mode
# May not have hit the max size yet, so potentially
# keep growing the cache
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
new_size = min(self.step, self.max_size - prev)
k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + 1, :] = keys
self.values[..., self._idx : self._idx + 1, :] = values
self.offset += 1
self._idx += 1
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
@property
def state(self):
return self.keys, self.values
@dataclass
@@ -156,25 +20,30 @@ class BaseModelArgs:
)
def create_additive_causal_mask(N: int, offset: int = 0):
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
linds = linds[:, None]
rinds = rinds[None]
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if isinstance(c, RotatingKVCache):
if hasattr(c, "max_size"):
offset = min(c.max_size - 1, c.offset)
window_size = c.max_size
else:
offset = c.offset
else:
offset = 0
mask = create_additive_causal_mask(T, offset)
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None

llms/mlx_lm/models/cache.py (new file, 340 lines)
View File

@@ -0,0 +1,340 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
"""
Construct the model's cache for use during generation.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``.
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
Save a pre-computed prompt cache to a file.
Args:
file_name (str): The ``.safetensors`` file name.
cache (List[Any]): The model state.
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
cache_data = [c.state for c in cache]
cache_info = [c.meta_state for c in cache]
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = [cache_info, metadata, cache_classes]
cache_metadata = dict(tree_flatten(cache_metadata))
mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False):
"""
Load a prompt cache from a file.
Args:
file_name (str): The ``.safetensors`` file name.
return_metadata (bool): Whether or not to return metadata.
Default: ``False``.
Returns:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
info, metadata, classes = cache_metadata
cache = [globals()[c]() for c in classes]
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
c.meta_state = meta_state
if return_metadata:
return cache, metadata
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if the model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> int:
"""
Trim the model's cache by the given number of tokens.
This function will trim the cache if possible (in-place) and return the
number of tokens that were trimmed.
Args:
cache (List[Any]): The model's cache.
num_tokens (int): The number of tokens to trim.
Returns:
(int): The number of tokens that were trimmed.
"""
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
class _BaseCache:
@property
def state(self):
return []
@state.setter
def state(self, v):
if v is not None and v:
raise ValueError("This cache has no state but a state was set.")
@property
def meta_state(self):
return ""
@meta_state.setter
def meta_state(self, v):
if v is not None and v:
raise ValueError("This cache has no meta_state but a meta_state was set.")
def is_trimmable(self):
return False
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
if self.offset == self.keys.shape[2]:
return self.keys, self.values
else:
return (
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
)
@state.setter
def state(self, v):
self.keys, self.values = v
self.offset = self.keys.shape[2]
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
class RotatingKVCache(_BaseCache):
def __init__(self, max_size=None, keep=0, step=256):
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def _temporal_order(self, v):
"""
Rearrange the cache into temporal order, slicing off the end if unused.
"""
if self._idx == v.shape[2]:
return v
elif self._idx < self.offset:
return mx.concatenate(
[
v[..., : self.keep, :],
v[..., self._idx :, :],
v[..., self.keep : self._idx, :],
],
axis=2,
)
else:
return v[..., : self._idx, :]
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
# Put the keys/values in temporal order to
# preserve context
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, n_kv_heads, S, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
v_head_dim = values.shape[3]
new_size = min(self.step, self.max_size - prev)
k_shape = (B, n_kv_heads, new_size, k_head_dim)
v_shape = (B, n_kv_heads, new_size, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
if keys.shape[2] == 1:
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
if self.offset < self.keys.shape[2]:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
else:
return self.keys, self.values
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
)
@meta_state.setter
def meta_state(self, v):
self.keep, self.max_size, self.step, self.offset, self._idx = map(
int,
v,
)
def is_trimmable(self):
return self.offset < self.max_size
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
self._idx -= n
return n
class MambaCache(_BaseCache):
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
@state.setter
def state(self, v):
self.cache = v

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -69,7 +69,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -129,7 +129,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
@@ -190,11 +190,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -49,7 +49,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
qkv = self.Wqkv(x)
@@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
x = h + x
@@ -179,7 +179,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r, h = self.norm_attn_norm(x, mask, cache)
out = self.ffn(h) + r
@@ -249,11 +249,3 @@ class Model(nn.Module):
experts = [(s, sv.T) for s, sv in experts]
new_weights.update(experts)
return new_weights
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.attn_config["kv_n_heads"]

View File

@@ -1,10 +1,10 @@
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
@@ -108,8 +108,8 @@ class DeepseekMLP(nn.Module):
def __init__(
self,
config: ModelArgs,
hidden_size: int | None = None,
intermediate_size: int | None = None,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
):
super().__init__()
self.config = config
@@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -210,7 +210,7 @@ class DeepseekModel(nn.Module):
def __call__(
self,
x: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
@@ -235,7 +235,7 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@@ -256,11 +256,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -2,12 +2,12 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Optional[Dict] = None
rope_scaling: Dict = None
attention_bias: bool = False
@@ -172,12 +172,11 @@ class DeepseekV2Attention(nn.Module):
bias=config.attention_bias,
)
if self.config.rope_scaling is not None:
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
@@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -221,17 +220,17 @@ class DeepseekV2Attention(nn.Module):
k_nope, values = mx.split(kv, [self.qk_nope_head_dim], axis=-1)
k_pe = mx.concatenate([k_pe] * self.num_heads, axis=1)
if cache is not None:
q_pe = self.rope(q_pe, cache.offset)
k_pe = self.rope(k_pe, cache.offset)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys, values = cache.update_and_fetch(
mx.concatenate([k_nope, k_pe], axis=-1), values
)
else:
q_pe = self.rope(q_pe)
k_pe = self.rope(k_pe)
k_pe = mx.repeat(k_pe, self.num_heads, axis=1)
keys = mx.concatenate([k_nope, k_pe], axis=-1)
queries = mx.concatenate([q_nope, q_pe], axis=-1)
@@ -292,7 +291,7 @@ class MoEGate(nn.Module):
scores = scores.reshape(bsz, seq_len, -1)
k = self.top_k
inds = mx.stop_gradient(mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k])
inds = mx.argpartition(-scores, kth=k - 1, axis=-1)[..., :k]
scores = mx.take_along_axis(scores, inds, axis=-1)
scores = scores * self.routed_scaling_factor
@@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module):
def __call__(
self,
x: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
@@ -395,7 +394,7 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@@ -416,14 +415,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
self.args.v_head_dim,
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -60,7 +60,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -113,7 +113,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -173,11 +173,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -64,7 +64,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
@@ -135,13 +135,11 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype(
mx.float32
)
r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
@@ -200,11 +198,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -46,7 +46,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@@ -196,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.h
@property
def head_dim(self):
return self.args.n_embd // self.args.n_head
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -57,7 +57,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -114,7 +114,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@@ -184,11 +184,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.n_embd // self.args.n_head
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -60,7 +60,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -120,7 +120,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
residual = x
# NeoX runs attention and feedforward network in parallel.
@@ -214,11 +214,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.h
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -116,7 +116,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -171,7 +171,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attention(self.attention_norm(x), mask, cache)
h = x + r
@@ -236,11 +236,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -171,7 +171,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -233,7 +233,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -303,13 +303,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -7,6 +7,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
@@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaCache:
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
@@ -223,7 +209,7 @@ class Model(nn.Module):
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self, batch_size: int = 1):
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -85,7 +85,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
):
B, L, _ = x.shape
@@ -135,7 +135,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
@@ -205,11 +205,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -2,7 +2,7 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -66,7 +66,7 @@ class MixtralAttention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -215,11 +215,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -2,12 +2,12 @@
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -94,7 +94,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
@@ -151,7 +151,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -215,13 +215,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,8 +1,8 @@
# Copyright © 2023-2024 Apple Inc.
import sys
from dataclasses import dataclass
from sys import exit
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -13,7 +13,7 @@ try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
sys.exit(1)
@dataclass
@@ -68,7 +68,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attend(self.att_norm(x), mask, cache)
h = x + r
@@ -174,11 +174,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.transformer.blocks
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.n_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -80,7 +80,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -152,7 +152,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.attn_norm(x), mask, cache)
h = x + r
@@ -218,11 +218,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_kv_heads

View File

@@ -162,19 +162,11 @@ class Model(nn.Module):
def __call__(
self,
x: mx.array,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
y = self.model(x, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .su_rope import SuScaledRotaryEmbedding
@@ -84,7 +84,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -143,7 +143,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -202,11 +202,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -3,12 +3,12 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: Optional[int] = None
num_key_value_heads: int
mup_attn_multiplier: float = 1.0
mup_use_scaling: bool = True
mup_embedding_multiplier: float = 10.0
mup_width_multiplier: float = 8.0
rope_embedding_base: float = 1000000
rope_position_scale: float = 1.0
blocksparse_block_size: Tuple[int] = (64,)
blocksparse_block_size: int = 64
blocksparse_num_local_blocks: int = 16
blocksparse_vert_stride: int = 8
@@ -61,7 +61,6 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_q_per_kv = n_heads // n_kv_heads
@@ -161,7 +160,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -230,7 +229,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -304,16 +303,8 @@ class Model(nn.Module):
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.args = args
self.model = PhiMoEModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
@@ -208,11 +209,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -168,8 +168,8 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
@@ -193,11 +193,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.model_dim // self.args.num_heads
@property
def n_kv_heads(self):
return self.args.num_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -62,8 +62,8 @@ class Attention(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
cache: Optional[Any] = None,
) -> mx.array:
bsz, q_len, _ = hidden_states.shape
queries = self.q_proj(hidden_states)
@@ -89,6 +89,9 @@ class Attention(nn.Module):
queries = self.rotary_emb(queries)
keys = self.rotary_emb(keys)
keys = mx.tile(keys, [1, self.config.n_shared_head, 1, 1])
values = mx.tile(values, [1, self.config.n_shared_head, 1, 1])
output = mx.fast.scaled_dot_product_attention(
queries,
keys,
@@ -127,8 +130,8 @@ class PlamoDecoderLayer(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[Any, ...]:
cache: Optional[Any] = None,
):
# from LlamaDecoder
residual = hidden_states
@@ -169,8 +172,8 @@ class PlamoModel(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
@@ -197,19 +200,11 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
) -> Tuple[mx.array, mx.array]:
cache: Optional[Any] = None,
) -> mx.array:
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads // self.args.n_shared_head

View File

@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -149,19 +148,11 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads

View File

@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -70,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -124,7 +124,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -196,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -2,12 +2,12 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -70,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -236,11 +236,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -7,13 +7,13 @@ from typing import List, Literal, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
from .cache import MambaCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
attention_bias: bool
conv1d_width: int
hidden_size: int
@@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs):
self.block_types = self._block_types
def create_window_causal_mask(N: int, window_size: int):
inds = mx.arange(N)
linds = inds[:, None]
rinds = inds[None]
mask = (linds < rinds) | (linds > rinds + window_size)
return mask * -1e9
class RecurrentCache:
def __init__(self):
self._cache = (None, None)
def __getitem__(self, idx):
return self._cache[idx]
def update(self, conv_state, recurrent_state):
self._cache = (conv_state, recurrent_state)
def state(self):
return self._cache
class WindowKVCache:
def __init__(self, window_size):
self.keys = None
self.values = None
self.offset = 0
self.window_size = window_size
def update_and_fetch(self, keys, values):
# TODO consider using rotating buffer here
# especially for very long generations
def _update(x, v):
t = x.shape[2] - self.window_size
if t > 0:
x = x[..., t:, :]
return mx.concatenate([x, v], axis=2)
self.offset += keys.shape[2]
if self.keys is None:
self.keys = keys
self.values = values
else:
self.keys = _update(self.keys, keys)
self.values = _update(self.values, values)
return self.keys, self.values
def state(self):
return self.keys, self.values
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
@@ -136,31 +83,22 @@ class Conv1d(nn.Module):
kernel_size: int,
):
super().__init__()
self.weight = mx.zeros((kernel_size, channels))
self.weight = mx.zeros((channels, kernel_size, 1))
self.bias = mx.zeros((channels,))
def __call__(self, x, cache=None):
w = self.weight.T[..., None]
kw, groups = self.weight.shape
if cache is not None:
l = []
# Pad the cache if needed
if cache.shape[1] < kw - 1:
l.append(
mx.zeros(
(x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
)
)
l.extend([cache, x])
x = mx.concatenate(l, axis=1)
y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
else:
y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
B, L, C = x.shape
groups, K, _ = self.weight.shape
# The cache is always kw - 1
cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
y = y + self.bias
return y, cache
return y, x[:, -K + 1 :, :]
class RGLRU(nn.Module):
@@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module):
# x branch.
x = self.linear_x(x)
if cache is None:
conv_state, recurrent_state = (None, None)
else:
conv_state, recurrent_state = cache[0], cache[1]
x, conv_state = self.conv_1d(
x=x,
cache=conv_state,
)
x, recurrent_state = self.rg_lru(
x=x,
cache=recurrent_state,
)
if cache is not None:
cache.update(conv_state, recurrent_state)
cache = [None, None]
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
x = x * y
x = self.linear_out(x)
@@ -467,12 +395,14 @@ class Griffin(nn.Module):
if self.scale_by_sqrt_dim:
x = x * math.sqrt(x.shape[-1])
mask = None
if x.shape[1] > 1:
mask = create_window_causal_mask(
x.shape[1], self.config.attention_window_size
)
mask = mask.astype(x.dtype)
if cache is None:
cache = [None] * len(self.layers)
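# Pick the cache from an attention (non-recurrent) layer so the attention
# mask is sized against that layer's key/value offset.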
for i, block in enumerate(self.layers):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
@@ -485,6 +415,7 @@ class Model(nn.Module):
def __init__(self, config):
self.args = config
self.model = Griffin(config)
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
@@ -508,10 +439,9 @@ class Model(nn.Module):
return self.model.layers
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
for k, v in weights.items():
if "conv_1d.weight" in k and v.ndim == 3:
weights[k] = v.squeeze(1).T
weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")
return weights
@@ -520,7 +450,7 @@ class Model(nn.Module):
cache = []
for layer in self.layers:
if layer.temporal_block_type == "recurrent":
cache.append(RecurrentCache())
cache.append(MambaCache())
else:
cache.append(WindowKVCache(self.args.attention_window_size))
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
return cache

View File

@@ -2,7 +2,6 @@
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -198,8 +197,8 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@@ -207,11 +206,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -45,7 +45,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -164,11 +164,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -3,19 +3,38 @@
import argparse
import json
import logging
import platform
import time
import uuid
import warnings
from dataclasses import dataclass, field
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, List, Literal, NamedTuple, Optional, Sequence, Union
from typing import (
Any,
Dict,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import mlx.core as mx
from huggingface_hub import scan_cache_dir
from ._version import __version__
from .models.cache import make_prompt_cache
from .utils import generate_step, load
def get_system_fingerprint():
gpu_arch = mx.metal.device_info()["architecture"] if mx.metal.is_available() else ""
return f"{__version__}-{mx.__version__}-{platform.platform()}-{gpu_arch}"
class StopCondition(NamedTuple):
stop_met: bool
trim_length: int
@@ -94,6 +113,13 @@ def convert_chat(messages: List[dict], role_mapping: Optional[dict] = None):
return prompt.rstrip()
@dataclass
class PromptCache:
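# Process-wide prompt cache shared across requests: the per-layer cache
# objects, the (model, adapter) key they were built for, and the token ids
# they currently contain.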
cache: List[Any] = field(default_factory=list)
model_key: Tuple[str, Optional[str]] = ("", None)
tokens: List[int] = field(default_factory=list)
class ModelProvider:
def __init__(self, cli_args: argparse.Namespace):
"""Load models on demand and persist them across the whole process."""
@@ -156,12 +182,21 @@ class ModelProvider:
class APIHandler(BaseHTTPRequestHandler):
def __init__(self, model_provider: ModelProvider, *args, **kwargs):
def __init__(
self,
model_provider: ModelProvider,
*args,
prompt_cache: Optional[PromptCache] = None,
system_fingerprint: Optional[str] = None,
**kwargs,
):
"""
Create static request-specific metadata
"""
self.created = int(time.time())
self.model_provider = model_provider
self.prompt_cache = prompt_cache or PromptCache()
self.system_fingerprint = system_fingerprint or get_system_fingerprint()
super().__init__(*args, **kwargs)
def _set_cors_headers(self):
@@ -215,7 +250,9 @@ class APIHandler(BaseHTTPRequestHandler):
self.stream_options = self.body.get("stream_options", None)
self.requested_model = self.body.get("model", "default_model")
self.adapter = self.body.get("adapters", None)
self.max_tokens = self.body.get("max_tokens", 100)
self.max_tokens = self.body.get("max_completion_tokens", None)
if self.max_tokens is None:
self.max_tokens = self.body.get("max_tokens", 512)
self.temperature = self.body.get("temperature", 1.0)
self.top_p = self.body.get("top_p", 1.0)
self.repetition_penalty = self.body.get("repetition_penalty", 1.0)
@@ -343,7 +380,7 @@ class APIHandler(BaseHTTPRequestHandler):
# Static response
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": self.object_type,
"model": self.requested_model,
"created": self.created,
@@ -388,16 +425,30 @@ class APIHandler(BaseHTTPRequestHandler):
return response
def get_prompt_cache(self, prompt):
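# Reuse the shared prompt cache when its tokens are a strict prefix of the
# new prompt for the same model; otherwise rebuild the cache and process the
# full prompt. Returns the portion of the prompt that still needs processing.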
cache_len = len(self.prompt_cache.tokens)
if (
self.prompt_cache.model_key != self.model_provider.model_key
or cache_len >= len(prompt)
or self.prompt_cache.tokens != prompt[:cache_len]
):
self.prompt_cache.model_key = self.model_provider.model_key
self.prompt_cache.cache = make_prompt_cache(self.model_provider.model)
else:
prompt = prompt[cache_len:]
self.prompt_cache.tokens.extend(prompt)
return prompt
def handle_completion(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
Generate a response to a prompt and send it to the client in a single batch.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed
to the stopping_criteria function
"""
@@ -409,17 +460,21 @@ class APIHandler(BaseHTTPRequestHandler):
logging.debug(f"Starting completion:")
token_logprobs = []
top_tokens = []
for (token, logprobs), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, logprobs) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
logit_bias=self.logit_bias,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@@ -430,7 +485,7 @@ class APIHandler(BaseHTTPRequestHandler):
top_indices = sorted_indices[: self.logprobs]
top_logprobs = logprobs[top_indices]
top_token_info = zip(top_indices.tolist(), top_logprobs.tolist())
top_tokens.append(dict(top_token_info))
top_tokens.append(tuple(top_token_info))
token_logprobs.append(logprobs[token].item())
@@ -445,6 +500,7 @@ class APIHandler(BaseHTTPRequestHandler):
)
break
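# Record the generated tokens so follow-up requests can reuse the cache.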
self.prompt_cache.tokens.extend(tokens)
detokenizer.finalize()
text = (
detokenizer.text
@@ -474,7 +530,7 @@ class APIHandler(BaseHTTPRequestHandler):
def handle_stream(
self,
prompt: mx.array,
prompt: List[int],
stop_id_sequences: List[List[int]],
):
"""
@@ -482,7 +538,7 @@ class APIHandler(BaseHTTPRequestHandler):
Sent Events (SSE) stream.
Args:
prompt (mx.array): The prompt, in token form inside of a mlx array
prompt (List[int]): The tokenized prompt.
stop_id_sequences (List[List[int]]): A list of stop words passed to
the stopping_criteria function
"""
@@ -496,16 +552,19 @@ class APIHandler(BaseHTTPRequestHandler):
stop_sequence_suffix = None
logging.debug(f"Starting stream:")
for (token, _), _ in zip(
prompt = self.get_prompt_cache(prompt)
for _, (token, _) in zip(
range(self.max_tokens),
generate_step(
prompt=prompt,
prompt=mx.array(prompt),
model=self.model,
temp=self.temperature,
top_p=self.top_p,
repetition_penalty=self.repetition_penalty,
repetition_context_size=self.repetition_context_size,
prompt_cache=self.prompt_cache.cache,
),
range(self.max_tokens),
):
detokenizer.add_token(token)
logging.debug(detokenizer.text)
@@ -531,9 +590,12 @@ class APIHandler(BaseHTTPRequestHandler):
continue
new_text = detokenizer.last_segment
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
if new_text:
response = self.generate_response(new_text, None)
self.wfile.write(f"data: {json.dumps(response)}\n\n".encode())
self.wfile.flush()
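# Keep the tracked token list in sync with the KV cache updated during generation.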
self.prompt_cache.tokens.extend(tokens)
# check if there is any remaining text to send
detokenizer.finalize()
@@ -559,7 +621,7 @@ class APIHandler(BaseHTTPRequestHandler):
):
response = {
"id": self.request_id,
"system_fingerprint": f"fp_{uuid.uuid4()}",
"system_fingerprint": self.system_fingerprint,
"object": "chat.completion",
"model": self.requested_model,
"created": self.created,
@@ -572,7 +634,7 @@ class APIHandler(BaseHTTPRequestHandler):
}
return response
def handle_chat_completions(self) -> mx.array:
def handle_chat_completions(self) -> List[int]:
"""
Handle a chat completion request.
@@ -587,7 +649,6 @@ class APIHandler(BaseHTTPRequestHandler):
self.object_type = (
"chat.completions.chunk" if self.stream else "chat.completions"
)
if (
hasattr(self.tokenizer, "apply_chat_template")
and self.tokenizer.chat_template
@@ -602,9 +663,9 @@ class APIHandler(BaseHTTPRequestHandler):
prompt = convert_chat(body["messages"], body.get("role_mapping"))
prompt = self.tokenizer.encode(prompt)
return mx.array(prompt)
return prompt
def handle_text_completions(self) -> mx.array:
def handle_text_completions(self) -> List[int]:
"""
Handle a text completion request.
@@ -614,11 +675,8 @@ class APIHandler(BaseHTTPRequestHandler):
# Determine response type
self.request_id = f"cmpl-{uuid.uuid4()}"
self.object_type = "text_completion"
assert "prompt" in self.body, "Request did not contain a prompt"
prompt_text = self.body["prompt"]
prompt = self.tokenizer.encode(prompt_text)
return mx.array(prompt)
return self.tokenizer.encode(self.body["prompt"])
def do_GET(self):
"""
@@ -669,9 +727,16 @@ def run(
handler_class=APIHandler,
):
server_address = (host, port)
prompt_cache = PromptCache()
httpd = server_class(
server_address,
lambda *args, **kwargs: handler_class(model_provider, *args, **kwargs),
lambda *args, **kwargs: handler_class(
model_provider,
prompt_cache=prompt_cache,
system_fingerprint=get_system_fingerprint(),
*args,
**kwargs,
),
)
warnings.warn(
"mlx_lm.server is not recommended for production as "

View File

@@ -97,6 +97,11 @@ class NaiveStreamingDetokenizer(StreamingDetokenizer):
def text(self):
if self._current_tokens:
self._current_text = self._tokenizer.decode(self._current_tokens)
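# Hold back a trailing space that the tokenizer's clean-up may remove once
# more tokens arrive.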
if (
self._tokenizer.clean_up_tokenization_spaces
and self._current_text[-1] == " "
):
self._current_text = self._current_text[:-1]
if self._current_text and self._current_text[-1] == "\n":
self._tokens.extend(self._current_tokens)
self._text += self._current_text
@@ -164,9 +169,11 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
"""
_byte_decoder = None
_space_matches = (".", "?", "!", ",", "'", "n't", "'m", "'s", "'ve", "'re")
def __init__(self, tokenizer, trim_space=False):
self.trim_space = trim_space
def __init__(self, tokenizer):
self.clean_spaces = tokenizer.clean_up_tokenization_spaces
# Build a list mapping token id to token text
self.tokenmap = [None] * len(tokenizer.vocab)
@@ -185,17 +192,22 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
self.text = ""
self.tokens = []
def _maybe_trim_space(self, current_text):
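# Drop a leading space at the very start of the text, or before punctuation
# and contractions when the tokenizer cleans up tokenization spaces.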
if current_text[0] != " ":
return current_text
elif not self.text:
return current_text[1:]
elif self.clean_spaces and current_text[1:].startswith(self._space_matches):
return current_text[1:]
return current_text
def add_token(self, token):
v = self.tokenmap[token]
# if the token starts with a space
if self._byte_decoder[v[0]] == 32:
current_text = bytearray(
self._byte_decoder[c] for c in self._unflushed
).decode("utf-8")
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self.text += self._maybe_trim_space(current_text)
self._unflushed = v
else:
self._unflushed += v
@@ -204,10 +216,7 @@ class BPEStreamingDetokenizer(StreamingDetokenizer):
current_text = bytearray(self._byte_decoder[c] for c in self._unflushed).decode(
"utf-8"
)
if self.text or not self.trim_space:
self.text += current_text
else:
self.text += _remove_space(current_text)
self.text += self._maybe_trim_space(current_text)
self._unflushed = ""
@classmethod
@@ -303,14 +312,7 @@ def _is_spm_decoder_no_space(decoder):
def _is_bpe_decoder(decoder):
_target_description = {
"type": "ByteLevel",
"add_prefix_space": False,
"trim_offsets": False,
"use_regex": False,
}
return _match(_target_description, decoder)
return isinstance(decoder, dict) and decoder.get("type", None) == "ByteLevel"
def load_tokenizer(model_path, tokenizer_config_extra={}):

View File

@@ -18,7 +18,7 @@ from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer
# Local imports
from .models.base import KVCache, RotatingKVCache
from .models import base, cache
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
from .tuner.utils import dequantize as dequantize_model
@@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
return logits
def make_kv_caches(
model: nn.Module, max_kv_size: Optional[int] = None
) -> List[Union[KVCache, RotatingKVCache]]:
if hasattr(model, "make_cache"):
return model.make_cache()
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
if max_kv_size is not None:
return [
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
for n in kv_heads
]
else:
return [KVCache(model.head_dim, n) for n in kv_heads]
def generate_step(
prompt: mx.array,
model: nn.Module,
@@ -155,7 +135,7 @@ def generate_step(
min_tokens_to_keep: int = 1,
prefill_step_size: int = 512,
max_kv_size: Optional[int] = None,
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
prompt_cache: Optional[Any] = None,
logit_bias: Optional[Dict[int, float]] = None,
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
@@ -180,6 +160,8 @@ def generate_step(
prefill_step_size (int): Step size for processing the prompt.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note that if
provided, the cache will be updated in place.
logit_bias (dictionary, optional): Additive logit bias.
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
@@ -237,20 +219,13 @@ def generate_step(
tokens = None
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
if cache_history is not None:
if len(cache_history) != len(cache):
raise ValueError("Wrong number of layers in the cache history")
# Set the history in the cache objects and evaluate them to prepare for
# generation.
for c, h in zip(cache, cache_history):
c.update_and_fetch(h[0], h[1])
mx.eval([c.state for c in cache])
if prompt_cache is None:
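# No cache was provided, so build a fresh per-layer cache for this generation.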
prompt_cache = cache.make_prompt_cache(model, max_kv_size)
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
def _step(y):
logits = model(y[None], cache=cache)
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
if logits_processor:
@@ -264,16 +239,17 @@ def generate_step(
return y, logprobs.squeeze(0)
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=cache)
mx.eval([c.state for c in cache])
model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
y = y[prefill_step_size:]
mx.metal.clear_cache()
y, logprobs = _step(y)
mx.async_eval(y)
mx.async_eval(y, logprobs)
while True:
next_y, next_logprobs = _step(y)
mx.async_eval(next_y)
mx.async_eval(next_y, next_logprobs)
yield y.item(), logprobs
y, logprobs = next_y, next_logprobs
@@ -305,9 +281,9 @@ def stream_generate(
detokenizer = tokenizer.detokenizer
detokenizer.reset()
for (token, _), n in zip(
generate_step(prompt_tokens, model, **kwargs),
for n, (token, _) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if token == tokenizer.eos_token_id:
break
@@ -357,9 +333,9 @@ def generate(
tic = time.perf_counter()
detokenizer.reset()
for (token, logprobs), n in zip(
generate_step(prompt_tokens, model, **kwargs),
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
@@ -372,7 +348,9 @@ def generate(
if formatter:
# We have to finalize so that the prob corresponds to the last segment
detokenizer.finalize()
formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
with mx.stream(mx.cpu):
prob = mx.exp(logprobs[token]).item()
formatter(detokenizer.last_segment, prob)
else:
print(detokenizer.last_segment, end="", flush=True)

View File

@@ -32,6 +32,7 @@ setup(
entry_points={
"console_scripts": [
"mlx_lm.cache_prompt = mlx_lm.cache_prompt:main",
"mlx_lm.chat = mlx_lm.chat:main",
"mlx_lm.convert = mlx_lm.convert:main",
"mlx_lm.fuse = mlx_lm.fuse:main",
"mlx_lm.generate = mlx_lm.generate:main",

View File

@@ -1,17 +1,15 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.base import KVCache, RotatingKVCache
from mlx_lm.utils import make_kv_caches
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
class TestModels(unittest.TestCase):
def test_kv_cache(self):
cache = KVCache(32, 4)
cache = KVCache()
k = mx.ones((1, 4, 1, 32), mx.float16)
v = mx.ones((1, 4, 1, 32), mx.float16)
@@ -32,7 +30,7 @@ class TestModels(unittest.TestCase):
def test_rotating_kv_cache(self):
b, h, d = 1, 2, 32
cache = RotatingKVCache(d, h, max_size=8, step=4)
cache = RotatingKVCache(max_size=8, step=4)
k = mx.random.uniform(shape=(b, h, 2, d))
v = mx.random.uniform(shape=(b, h, 2, d))
@@ -65,7 +63,7 @@ class TestModels(unittest.TestCase):
idx %= 8
# Try with nonzero keep
cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
cache = RotatingKVCache(max_size=8, step=4, keep=2)
# Check a large update
k = mx.random.uniform(shape=(b, h, 20, d))
@@ -88,6 +86,46 @@ class TestModels(unittest.TestCase):
if idx >= 8:
idx = 2
def test_rotating_kv_cache_chat_mode(self):
# Test that the rotating kv cache can handle
# alternating prompt/prefill with generation
d = 4
h = 2
cache = RotatingKVCache(max_size=18, step=4)
x = mx.random.uniform(shape=(1, h, 8, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 8)
self.assertEqual(cache.offset, 8)
x = mx.random.uniform(shape=(1, h, 1, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 9)
self.assertEqual(cache.offset, 9)
self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 11)
self.assertEqual(cache.offset, 11)
self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
x = mx.random.uniform(shape=(1, h, 3, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 14)
self.assertEqual(cache.offset, 14)
self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
x = mx.random.uniform(shape=(1, h, 6, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 20)
self.assertTrue(mx.allclose(x, k[..., -6:, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)
@@ -101,7 +139,7 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
cache = make_kv_caches(model)
cache = make_prompt_cache(model)
outputs = model(inputs, cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
@@ -549,6 +587,179 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek(self):
from mlx_lm.models import deepseek
args = deepseek.ModelArgs(
model_type="deepseek",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=4,
)
model = deepseek.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek_v2(self):
from mlx_lm.models import deepseek_v2
args = deepseek_v2.ModelArgs(
model_type="deepseek_v2",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
kv_lora_rank=4,
q_lora_rank=4,
qk_rope_head_dim=32,
v_head_dim=16,
qk_nope_head_dim=32,
rope_scaling={
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
},
)
model = deepseek_v2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma2(self):
from mlx_lm.models import gemma2
args = gemma2.ModelArgs(
model_type="gemma2",
hidden_size=128,
num_hidden_layers=4,
intermediate_size=256,
num_attention_heads=2,
head_dim=32,
rms_norm_eps=1e-4,
vocab_size=1024,
num_key_value_heads=2,
)
model = gemma2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gpt_bigcode(self):
from mlx_lm.models import gpt_bigcode
args = gpt_bigcode.ModelArgs(
model_type="gpt_bigcode",
n_embd=128,
n_layer=128,
n_inner=256,
n_head=4,
n_positions=1000,
layer_norm_epsilon=1e-5,
vocab_size=1024,
)
model = gpt_bigcode.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
def test_nemotron(self):
from mlx_lm.models import nemotron
args = nemotron.ModelArgs(
model_type="nemotron",
hidden_size=128,
hidden_act="gelu",
num_hidden_layers=4,
intermediate_size=256,
num_attention_heads=4,
norm_eps=1e-5,
vocab_size=1024,
num_key_value_heads=2,
)
model = nemotron.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phi3small(self):
from mlx_lm.models import phi3small
args = phi3small.ModelArgs(
model_type="phi3small",
hidden_size=128,
dense_attention_every_n_layers=2,
ff_intermediate_size=256,
gegelu_limit=1.0,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
layer_norm_epsilon=1e-4,
vocab_size=1000,
)
model = phi3small.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phimoe(self):
from mlx_lm.models import phimoe
args = phimoe.ModelArgs(
model_type="phimoe",
vocab_size=320,
hidden_size=128,
intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=4,
rope_scaling={
"long_factor": [1.0] * 16,
"long_mscale": 1.243163121016122,
"original_max_position_embeddings": 4096,
"short_factor": [1.0] * 16,
"short_mscale": 1.243163121016122,
"type": "longrope",
},
)
model = phimoe.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_recurrent_gemma(self):
from mlx_lm.models import recurrent_gemma
args = recurrent_gemma.ModelArgs(
model_type="recurrent_gemma",
hidden_size=128,
attention_bias=False,
conv1d_width=3,
intermediate_size=256,
logits_soft_cap=1.0,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
rms_norm_eps=1e-4,
rope_theta=1000,
attention_window_size=1024,
vocab_size=1000,
block_types=["recurrent", "recurrent", "attention"],
)
model = recurrent_gemma.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,243 @@
# Copyright © 2024 Apple Inc.
import copy
import os
import tempfile
import unittest
import mlx.core as mx
from mlx_lm.models.cache import (
KVCache,
MambaCache,
RotatingKVCache,
load_prompt_cache,
make_prompt_cache,
save_prompt_cache,
trim_prompt_cache,
)
from mlx_lm.utils import generate_step, load
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
class TestPromptCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_save_load(self):
cache = [KVCache() for _ in range(4)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
self.assertEqual(len(cache), len(loaded_cache))
for c, lc in zip(cache, loaded_cache):
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
# Test with metadata
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
metadata = {"a": "b", "c": "d"}
save_prompt_cache(cache_file, cache, metadata)
_, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
self.assertEqual(metadata, loaded_metadata)
def test_save_load_rotating_cache(self):
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
# Test with rotating cache
cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
self.assertEqual(len(cache), len(loaded_cache))
for c, lc in zip(cache, loaded_cache):
self.assertEqual(c.offset, lc.offset)
self.assertEqual(c.keep, lc.keep)
self.assertEqual(c.max_size, lc.max_size)
self.assertEqual(c.step, lc.step)
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
# Do a couple single token updates to get a rotation
for _ in range(2):
for c in cache:
x = mx.random.uniform(shape=(1, 8, 1, 4))
c.update_and_fetch(x, x)
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
for c, lc in zip(cache, loaded_cache):
x = mx.random.uniform(shape=(1, 8, 1, 4))
k, v = c.update_and_fetch(x, x)
lk, lv = lc.update_and_fetch(x, x)
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(k, lk))
self.assertTrue(mx.array_equal(v, lv))
def test_save_load_mixed_cache(self):
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()]
for c in cache:
if isinstance(c, MambaCache):
c[0] = mx.random.uniform(shape=(4, 4, 4))
c[1] = mx.random.uniform(shape=(4, 4, 4))
else:
x = mx.random.uniform(shape=(4, 4, 7, 4))
y = mx.random.uniform(shape=(4, 4, 7, 4))
c.update_and_fetch(x, y)
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
for c, lc in zip(cache, loaded_cache):
if isinstance(c, MambaCache):
self.assertTrue(mx.array_equal(c[0], lc[0]))
self.assertTrue(mx.array_equal(c[1], lc[1]))
else:
x = mx.random.uniform(shape=(4, 4, 1, 4))
y = mx.random.uniform(shape=(4, 4, 1, 4))
k, v = c.update_and_fetch(x, y)
lk, lv = lc.update_and_fetch(x, y)
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(k, lk))
self.assertTrue(mx.array_equal(v, lv))
def test_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
results = zip(range(4), generate_step(prompt, model))
toks, all_logits = zip(*(r[1] for r in results))
prompt_cache = make_prompt_cache(model)
i = 0
for _, (tok, logits) in zip(
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
):
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
i += 1
for _, (tok, logits) in zip(
range(1),
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
):
i += 1
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
def test_trim_cache(self):
cache = [KVCache() for _ in range(2)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
# Trim
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 7)
# Trim more tokens than remain
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 3)
# Can't trim mamba cache
cache = [MambaCache() for _ in range(2)]
for c in cache:
c.state = mx.zeros((5, 5))
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 0)
# All caches have to be trimmable
cache = [MambaCache(), KVCache()]
cache[0].state = mx.zeros((5, 5))
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[1].update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 1)
self.assertEqual(num_trimmed, 0)
cache = [RotatingKVCache(max_size=6) for _ in range(2)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 5, 4))
c.update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 4)
# Can't trim fixed-size KV cache after processing
# more than max_kv_size tokens
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 0)
def test_trim_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
prompt_cache = make_prompt_cache(model)
# Generate one token so we process the full prompt
last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
last_tok = mx.array([last_tok])
# Generate two more tokens
results = zip(
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
)
toks, all_logits = zip(*(r[1] for r in results))
# To get back to the cache just after processing the prompt,
# trim by 3 tokens
trim_prompt_cache(prompt_cache, 3)
# Generate the same thing again
results = zip(
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
)
second_toks, second_all_logits = zip(*(r[1] for r in results))
self.assertEqual(toks, second_toks)
self.assertTrue(
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
)
def test_cache_copying(self):
cache = [KVCache()]
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[0].update_and_fetch(x, x)
y = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(y, y)
old_cache = copy.deepcopy(cache)
trim_prompt_cache(cache, 1)
self.assertEqual(old_cache[0].offset, 11)
self.assertEqual(cache[0].offset, 10)
z = mx.random.uniform(shape=(1, 8, 1, 4))
cache[0].update_and_fetch(z, z)
self.assertTrue(mx.allclose(old_cache[0].keys[..., 10:11, :], y))
self.assertTrue(mx.allclose(cache[0].keys[..., 10:11, :], z))
if __name__ == "__main__":
unittest.main()

View File

@@ -14,6 +14,7 @@ class DummyModelProvider:
def __init__(self):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
self.model, self.tokenizer = load(HF_MODEL_PATH)
self.model_key = (HF_MODEL_PATH, None)
def load(self, model, adapter=None):
assert model in ["default_model", "chat_model"]

View File

@@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.
import unittest
from pathlib import Path
from huggingface_hub import snapshot_download
from mlx_lm.tokenizer_utils import (
BPEStreamingDetokenizer,
NaiveStreamingDetokenizer,
SPMStreamingDetokenizer,
load_tokenizer,
)
class TestTokenizers(unittest.TestCase):
def download_tokenizer(self, repo):
path = Path(
snapshot_download(
repo_id=repo,
allow_patterns=[
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"tokenizer.model",
],
)
)
return load_tokenizer(path)
def check_tokenizer(self, tokenizer):
def check(tokens):
expected_text = tokenizer.decode(tokens)
detokenizer = tokenizer.detokenizer
detokenizer.reset()
text = ""
for t in tokens:
detokenizer.add_token(t)
seg = detokenizer.last_segment
text += seg
detokenizer.finalize()
text += detokenizer.last_segment
self.assertEqual(text, expected_text)
tokens = tokenizer.encode("a ,b")
check(tokens)
tokens = tokenizer.encode('{"why_its_funny" :"a_joke_explainer" ,"rating":3.5}')
check(tokens)
tokens = tokenizer.encode("3 3")
check(tokens)
def test_tokenizers(self):
tokenizer_repos = [
("mlx-community/Qwen1.5-0.5B-Chat-4bit", BPEStreamingDetokenizer),
("mlx-community/Mistral-7B-v0.2-4bit", SPMStreamingDetokenizer),
("mlx-community/Phi-3.5-mini-instruct-4bit", SPMStreamingDetokenizer),
("mlx-community/Mistral-7B-Instruct-v0.3", SPMStreamingDetokenizer),
("mlx-community/Llama-3.2-1B-Instruct-4bit", BPEStreamingDetokenizer),
]
for tokenizer_repo, expected_detokenizer in tokenizer_repos:
with self.subTest(tokenizer=tokenizer_repo):
tokenizer = self.download_tokenizer(tokenizer_repo)
tokenizer.decode([0, 1, 2])
self.assertTrue(isinstance(tokenizer.detokenizer, expected_detokenizer))
self.check_tokenizer(tokenizer)
# Try one with a naive detokenizer
tokenizer = self.download_tokenizer("mlx-community/Llama-3.2-1B-Instruct-4bit")
tokenizer._detokenizer = NaiveStreamingDetokenizer(tokenizer)
self.check_tokenizer(tokenizer)
if __name__ == "__main__":
unittest.main()