mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-18 10:26:58 +08:00
reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -6,6 +6,9 @@ __pycache__/
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Vim
|
||||
*.swp
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
|
@@ -155,14 +155,14 @@ different queries. To cache a prompt use `mlx_lm.cache_prompt`. For example:
|
||||
cat prompt.txt | mlx_lm.cache_prompt \
|
||||
--model mistralai/Mistral-7B-Instruct-v0.3 \
|
||||
--prompt - \
|
||||
--kv-cache-file mistral_prompt.safetensors
|
||||
--prompt-cache-file mistral_prompt.safetensors
|
||||
```
|
||||
|
||||
Then use the cached prompt with `mlx_lm.generate`:
|
||||
|
||||
```
|
||||
mlx_lm.generate \
|
||||
--kv-cache-file mistral_prompt.safetensors \
|
||||
--prompt-cache-file mistral_prompt.safetensors \
|
||||
--prompt "\nSummarize the above text."
|
||||
```
|
||||
|
||||
|
@@ -7,13 +7,14 @@ import time
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .utils import load, make_kv_caches
|
||||
from .models.cache import make_prompt_cache, save_prompt_cache
|
||||
from .utils import load
|
||||
|
||||
|
||||
def setup_arg_parser():
|
||||
"""Set up and return the argument parser."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Cache the KV cache of a prompt to be reused with mlx_lm.generate"
|
||||
description="Cache the state of a prompt to be reused with mlx_lm.generate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
@@ -60,7 +61,9 @@ def setup_arg_parser():
|
||||
help="Set the maximum key-value cache size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-file", help="The file to save the KV caches in", required=True
|
||||
"--prompt-cache-file",
|
||||
help="The file to save the prompt cache in",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt",
|
||||
@@ -115,7 +118,7 @@ def main():
|
||||
else:
|
||||
prompt = args.prompt
|
||||
|
||||
cache = make_kv_caches(model, args.max_kv_size)
|
||||
cache = make_prompt_cache(model, args.max_kv_size)
|
||||
y = mx.array(tokenizer.encode(prompt))
|
||||
|
||||
# Process the prompt
|
||||
@@ -137,16 +140,12 @@ def main():
|
||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
||||
|
||||
print("Saving...")
|
||||
cache_dict = {}
|
||||
for i, c in enumerate(cache):
|
||||
cache_dict[f"{i}_keys"] = c.state[0][..., : c.offset, :]
|
||||
cache_dict[f"{i}_values"] = c.state[1][..., : c.offset, :]
|
||||
metadata = {}
|
||||
metadata["model"] = args.model
|
||||
metadata["chat_template"] = tokenizer.chat_template
|
||||
metadata["tokenizer_config"] = json.dumps(tokenizer_config)
|
||||
metadata["max_kv_size"] = str(args.max_kv_size)
|
||||
mx.save_safetensors(args.kv_cache_file, cache_dict, metadata)
|
||||
print(f"Peak memory: {mx.metal.get_peak_memory() / 2**30:.3f} GB")
|
||||
save_prompt_cache(args.prompt_cache_file, cache, metadata)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
50
llms/mlx_lm/examples/chat.py
Normal file
50
llms/mlx_lm/examples/chat.py
Normal file
@@ -0,0 +1,50 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
"""
|
||||
An example of a multi-turn chat with prompt caching.
|
||||
"""
|
||||
|
||||
from mlx_lm import generate, load
|
||||
from mlx_lm.models.cache import make_prompt_cache
|
||||
|
||||
model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
|
||||
|
||||
# Make the initial prompt cache for the model
|
||||
prompt_cache = make_prompt_cache(model)
|
||||
|
||||
# User turn
|
||||
prompt = "Hi my name is <Name>."
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Assistant response
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
max_tokens=1024,
|
||||
temp=0.0,
|
||||
prompt_cache=prompt_cache,
|
||||
)
|
||||
messages.append({"role": "assistant", "content": response})
|
||||
|
||||
# User turn
|
||||
prompt = "What's my name?"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
# Assistant response
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
prompt=prompt,
|
||||
verbose=True,
|
||||
max_tokens=1024,
|
||||
temp=0.0,
|
||||
prompt_cache=prompt_cache,
|
||||
)
|
@@ -1,3 +1,5 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from mlx_lm import generate, load
|
||||
|
||||
# Specify the checkpoint
|
||||
|
@@ -6,6 +6,7 @@ import sys
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .models.cache import load_prompt_cache
|
||||
from .utils import generate, load
|
||||
|
||||
DEFAULT_PROMPT = "hello"
|
||||
@@ -96,7 +97,7 @@ def setup_arg_parser():
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-file",
|
||||
"--prompt-cache-file",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A file containing saved KV caches to avoid recomputing them",
|
||||
@@ -131,24 +132,6 @@ def colorprint_by_t0(s, t0):
|
||||
colorprint(color, s)
|
||||
|
||||
|
||||
def load_kv_cache_from_file(kv_cache_file):
|
||||
if kv_cache_file is None:
|
||||
return None, None
|
||||
|
||||
kv_cache, metadata = mx.load(kv_cache_file, return_metadata=True)
|
||||
cache_per_layer = {}
|
||||
for k, x in kv_cache.items():
|
||||
layer, kv_type = k.split("_")
|
||||
if layer not in cache_per_layer:
|
||||
cache_per_layer[layer] = {}
|
||||
cache_per_layer[layer][kv_type] = x
|
||||
|
||||
cache_history = [None] * len(cache_per_layer)
|
||||
for layer, c in cache_per_layer.items():
|
||||
cache_history[int(layer)] = (c["keys"], c["values"])
|
||||
return cache_history, metadata
|
||||
|
||||
|
||||
def main():
|
||||
parser = setup_arg_parser()
|
||||
args = parser.parse_args()
|
||||
@@ -158,22 +141,32 @@ def main():
|
||||
if args.cache_limit_gb is not None:
|
||||
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
|
||||
|
||||
# Load the kv cache and metadata if a kv cache file is provided
|
||||
cache_history, metadata = load_kv_cache_from_file(args.kv_cache_file)
|
||||
# Load the prompt cache and metadata if a cache file is provided
|
||||
using_cache = args.prompt_cache_file is not None
|
||||
if using_cache:
|
||||
prompt_cache, metadata = load_prompt_cache(
|
||||
args.prompt_cache_file, return_metadata=True
|
||||
)
|
||||
|
||||
# Building tokenizer_config
|
||||
tokenizer_config = (
|
||||
{} if cache_history is None else json.loads(metadata["tokenizer_config"])
|
||||
{} if not using_cache else json.loads(metadata["tokenizer_config"])
|
||||
)
|
||||
if args.trust_remote_code:
|
||||
tokenizer_config["trust_remote_code"] = True
|
||||
if args.eos_token is not None:
|
||||
tokenizer_config["eos_token"] = args.eos_token
|
||||
|
||||
# If no model path is provided then use the one in the kv cache history
|
||||
model_path = args.model
|
||||
if cache_history is not None and model_path is None:
|
||||
if using_cache:
|
||||
if model_path is None:
|
||||
model_path = metadata["model"]
|
||||
elif model_path != metadata["model"]:
|
||||
raise ValueError(
|
||||
f"Providing a different model ({model_path}) than that "
|
||||
f"used to create the prompt cache ({metadata['model']}) "
|
||||
"is an error."
|
||||
)
|
||||
|
||||
model, tokenizer = load(
|
||||
model_path,
|
||||
@@ -184,7 +177,7 @@ def main():
|
||||
if args.use_default_chat_template:
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = tokenizer.default_chat_template
|
||||
elif cache_history is not None:
|
||||
elif using_cache:
|
||||
tokenizer.chat_template = metadata["chat_template"]
|
||||
|
||||
if not args.ignore_chat_template and (
|
||||
@@ -203,7 +196,7 @@ def main():
|
||||
|
||||
# Treat the prompt as a suffix assuming that the prefix is in the
|
||||
# stored kv cache.
|
||||
if cache_history is not None:
|
||||
if using_cache:
|
||||
test_prompt = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": "<query>"}],
|
||||
tokenize=False,
|
||||
@@ -217,12 +210,6 @@ def main():
|
||||
raise ValueError("Cannot use --colorize with --verbose=False")
|
||||
formatter = colorprint_by_t0 if args.colorize else None
|
||||
|
||||
# Determine the max kv size from the kv cache or passed arguments
|
||||
max_kv_size = args.max_kv_size
|
||||
if cache_history is not None:
|
||||
max_kv_size = metadata["max_kv_size"]
|
||||
max_kv_size = int(max_kv_size) if max_kv_size.isdigit() else None
|
||||
|
||||
response = generate(
|
||||
model,
|
||||
tokenizer,
|
||||
@@ -232,8 +219,8 @@ def main():
|
||||
formatter=formatter,
|
||||
temp=args.temp,
|
||||
top_p=args.top_p,
|
||||
max_kv_size=max_kv_size,
|
||||
cache_history=cache_history,
|
||||
max_kv_size=args.max_kv_size,
|
||||
prompt_cache=prompt_cache if using_cache else None,
|
||||
)
|
||||
if not args.verbose:
|
||||
print(response)
|
||||
|
@@ -2,153 +2,9 @@
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
|
||||
class KVCache:
|
||||
|
||||
def __init__(self, head_dim, n_kv_heads):
|
||||
self.n_kv_heads = n_kv_heads
|
||||
if isinstance(head_dim, int):
|
||||
self.k_head_dim = self.v_head_dim = head_dim
|
||||
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
|
||||
self.k_head_dim, self.v_head_dim = head_dim
|
||||
else:
|
||||
raise ValueError("head_dim must be an int or a tuple of two ints")
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||
B = keys.shape[0]
|
||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
|
||||
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = self.keys[..., :prev, :]
|
||||
self.values = self.values[..., :prev, :]
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
self.keys[..., prev : self.offset, :] = keys
|
||||
self.values[..., prev : self.offset, :] = values
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
class RotatingKVCache:
|
||||
|
||||
def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
|
||||
self.n_kv_heads = n_kv_heads
|
||||
if isinstance(head_dim, int):
|
||||
self.k_head_dim = self.v_head_dim = head_dim
|
||||
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
|
||||
self.k_head_dim, self.v_head_dim = head_dim
|
||||
else:
|
||||
raise ValueError("head_dim must be an int or a tuple of two ints")
|
||||
self.keep = keep
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.max_size = max_size
|
||||
self.step = step
|
||||
self._idx = 0
|
||||
|
||||
def _trim(self, trim_size, v, append=None):
|
||||
to_cat = []
|
||||
if trim_size > 0:
|
||||
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
|
||||
else:
|
||||
to_cat = [v]
|
||||
if append is not None:
|
||||
to_cat.append(append)
|
||||
return mx.concatenate(to_cat, axis=2)
|
||||
|
||||
def _update_concat(self, keys, values):
|
||||
if self.keys is None:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
else:
|
||||
if self._idx < self.keys.shape[2]:
|
||||
self.keys = self.keys[..., : self._idx, :]
|
||||
self.values = self.values[..., : self._idx, :]
|
||||
|
||||
# The largest size is self.max_size + S - 1 to ensure
|
||||
# every token gets at least self.max_size context
|
||||
trim_size = self._idx - self.max_size + 1
|
||||
self.keys = self._trim(trim_size, self.keys, keys)
|
||||
self.values = self._trim(trim_size, self.values, values)
|
||||
self.offset += keys.shape[2]
|
||||
self._idx = self.keys.shape[2]
|
||||
return self.keys, self.values
|
||||
|
||||
def _update_in_place(self, keys, values):
|
||||
# May not have hit the max size yet, so potentially
|
||||
# keep growing the cache
|
||||
B, _, S = keys.shape[:3]
|
||||
prev = self.offset
|
||||
if self.keys is None or (
|
||||
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
|
||||
):
|
||||
new_size = min(self.step, self.max_size - prev)
|
||||
k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
|
||||
v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
self._idx = prev
|
||||
|
||||
# Trim if needed
|
||||
trim_size = self.keys.shape[2] - self.max_size
|
||||
if trim_size > 0:
|
||||
self.keys = self._trim(trim_size, self.keys)
|
||||
self.values = self._trim(trim_size, self.values)
|
||||
self._idx = self.max_size
|
||||
|
||||
# Rotate
|
||||
if self._idx == self.max_size:
|
||||
self._idx = self.keep
|
||||
|
||||
# Assign
|
||||
self.keys[..., self._idx : self._idx + S, :] = keys
|
||||
self.values[..., self._idx : self._idx + S, :] = values
|
||||
self.offset += S
|
||||
self._idx += S
|
||||
|
||||
# If the buffer is not full, slice off the end
|
||||
if self.offset < self.max_size:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
return self.keys, self.values
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
S = keys.shape[2]
|
||||
if S == 1 or (self.keys is not None and S < (self.keys.shape[2] - self._idx)):
|
||||
return self._update_in_place(keys, values)
|
||||
|
||||
return self._update_concat(keys, values)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -164,25 +20,30 @@ class BaseModelArgs:
|
||||
)
|
||||
|
||||
|
||||
def create_additive_causal_mask(N: int, offset: int = 0):
|
||||
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
|
||||
rinds = mx.arange(offset + N)
|
||||
linds = mx.arange(offset, offset + N) if offset else rinds
|
||||
mask = linds[:, None] < rinds[None]
|
||||
linds = linds[:, None]
|
||||
rinds = rinds[None]
|
||||
mask = linds < rinds
|
||||
if window_size is not None:
|
||||
mask = mask | (linds > rinds + window_size)
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
|
||||
T = h.shape[1]
|
||||
if T > 1:
|
||||
window_size = None
|
||||
offset = 0
|
||||
if cache is not None and cache[0] is not None:
|
||||
c = cache[0]
|
||||
if isinstance(c, RotatingKVCache):
|
||||
if hasattr(c, "max_size"):
|
||||
offset = min(c.max_size - 1, c.offset)
|
||||
window_size = c.max_size
|
||||
else:
|
||||
offset = c.offset
|
||||
else:
|
||||
offset = 0
|
||||
mask = create_additive_causal_mask(T, offset)
|
||||
mask = create_causal_mask(T, offset, window_size=window_size)
|
||||
mask = mask.astype(h.dtype)
|
||||
else:
|
||||
mask = None
|
||||
|
257
llms/mlx_lm/models/cache.py
Normal file
257
llms/mlx_lm/models/cache.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
|
||||
if hasattr(model, "make_cache"):
|
||||
return model.make_cache()
|
||||
|
||||
num_layers = len(model.layers)
|
||||
if max_kv_size is not None:
|
||||
return [
|
||||
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
|
||||
]
|
||||
else:
|
||||
return [KVCache() for _ in range(num_layers)]
|
||||
|
||||
|
||||
def save_prompt_cache(
|
||||
file_name: str, cache: List[Any], metadata: Optional[Dict[str, str]] = None
|
||||
):
|
||||
"""
|
||||
Save a pre-computed prompt cache to a file.
|
||||
"""
|
||||
cache_data, cache_info = zip(*(c.state for c in cache))
|
||||
cache_data = dict(tree_flatten(cache_data))
|
||||
cache_classes = [type(c).__name__ for c in cache]
|
||||
cache_metadata = [cache_classes, cache_info]
|
||||
if metadata:
|
||||
cache_metadata.append(metadata)
|
||||
cache_metadata = dict(tree_flatten(cache_metadata))
|
||||
mx.save_safetensors(file_name, cache_data, cache_metadata)
|
||||
|
||||
|
||||
def load_prompt_cache(file_name, return_metadata=False):
|
||||
"""
|
||||
Load a prompt cache from a file.
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
return_metadata (bool): Whether or not to return metadata. Default:
|
||||
``False``.
|
||||
|
||||
Returns:
|
||||
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
|
||||
the metadata if requested.
|
||||
"""
|
||||
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
|
||||
arrays = tree_unflatten(list(arrays.items()))
|
||||
cache_metadata = tree_unflatten(list(cache_metadata.items()))
|
||||
classes, info = cache_metadata[:2]
|
||||
cache = [globals()[c]() for c in classes]
|
||||
for c, *state in zip(cache, arrays, info):
|
||||
c.state = state
|
||||
if return_metadata:
|
||||
return cache, cache_metadata[2]
|
||||
return cache
|
||||
|
||||
|
||||
class KVCache:
|
||||
|
||||
def __init__(self):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.step = 256
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
prev = self.offset
|
||||
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
|
||||
B, n_kv_heads, _, k_head_dim = keys.shape
|
||||
v_head_dim = values.shape[3]
|
||||
n_steps = (self.step + keys.shape[2] - 1) // self.step
|
||||
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
|
||||
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
if prev % self.step != 0:
|
||||
self.keys = self.keys[..., :prev, :]
|
||||
self.values = self.values[..., :prev, :]
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
self.keys[..., prev : self.offset, :] = keys
|
||||
self.values[..., prev : self.offset, :] = values
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset == self.keys.shape[2]:
|
||||
return (self.keys, self.values), ""
|
||||
else:
|
||||
return (
|
||||
self.keys[..., : self.offset, :],
|
||||
self.values[..., : self.offset, :],
|
||||
), ""
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v[0]
|
||||
self.offset = self.keys.shape[2]
|
||||
|
||||
|
||||
class RotatingKVCache:
|
||||
|
||||
def __init__(self, max_size=None, keep=0, step=256):
|
||||
self.keep = keep
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.max_size = max_size
|
||||
self.step = step
|
||||
self._idx = 0
|
||||
|
||||
def _trim(self, trim_size, v, append=None):
|
||||
to_cat = []
|
||||
if trim_size > 0:
|
||||
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
|
||||
else:
|
||||
to_cat = [v]
|
||||
if append is not None:
|
||||
to_cat.append(append)
|
||||
return mx.concatenate(to_cat, axis=2)
|
||||
|
||||
def _temporal_order(self, v):
|
||||
"""
|
||||
Rearrange the cache into temporal order, slicing off the end if unused.
|
||||
"""
|
||||
if self._idx == v.shape[2]:
|
||||
return v
|
||||
elif self._idx < self.offset:
|
||||
return mx.concatenate(
|
||||
[
|
||||
v[..., : self.keep, :],
|
||||
v[..., self._idx :, :],
|
||||
v[..., self.keep : self._idx, :],
|
||||
],
|
||||
axis=2,
|
||||
)
|
||||
else:
|
||||
return v[..., : self._idx, :]
|
||||
|
||||
def _update_concat(self, keys, values):
|
||||
if self.keys is None:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
else:
|
||||
# Put the keys/values in temporal order to
|
||||
# preserve context
|
||||
self.keys = self._temporal_order(self.keys)
|
||||
self.values = self._temporal_order(self.values)
|
||||
|
||||
# The largest size is self.max_size + S - 1 to ensure
|
||||
# every token gets at least self.max_size context
|
||||
trim_size = self._idx - self.max_size + 1
|
||||
self.keys = self._trim(trim_size, self.keys, keys)
|
||||
self.values = self._trim(trim_size, self.values, values)
|
||||
self.offset += keys.shape[2]
|
||||
self._idx = self.keys.shape[2]
|
||||
return self.keys, self.values
|
||||
|
||||
def _update_in_place(self, keys, values):
|
||||
# May not have hit the max size yet, so potentially
|
||||
# keep growing the cache
|
||||
B, n_kv_heads, S, k_head_dim = keys.shape
|
||||
prev = self.offset
|
||||
if self.keys is None or (
|
||||
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
|
||||
):
|
||||
v_head_dim = values.shape[3]
|
||||
new_size = min(self.step, self.max_size - prev)
|
||||
k_shape = (B, n_kv_heads, new_size, k_head_dim)
|
||||
v_shape = (B, n_kv_heads, new_size, v_head_dim)
|
||||
new_k = mx.zeros(k_shape, keys.dtype)
|
||||
new_v = mx.zeros(v_shape, values.dtype)
|
||||
if self.keys is not None:
|
||||
self.keys = mx.concatenate([self.keys, new_k], axis=2)
|
||||
self.values = mx.concatenate([self.values, new_v], axis=2)
|
||||
else:
|
||||
self.keys, self.values = new_k, new_v
|
||||
self._idx = prev
|
||||
|
||||
# Trim if needed
|
||||
trim_size = self.keys.shape[2] - self.max_size
|
||||
if trim_size > 0:
|
||||
self.keys = self._trim(trim_size, self.keys)
|
||||
self.values = self._trim(trim_size, self.values)
|
||||
self._idx = self.max_size
|
||||
|
||||
# Rotate
|
||||
if self._idx == self.max_size:
|
||||
self._idx = self.keep
|
||||
|
||||
# Assign
|
||||
self.keys[..., self._idx : self._idx + S, :] = keys
|
||||
self.values[..., self._idx : self._idx + S, :] = values
|
||||
self.offset += S
|
||||
self._idx += S
|
||||
|
||||
# If the buffer is not full, slice off the end
|
||||
if self.offset < self.max_size:
|
||||
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
|
||||
return self.keys, self.values
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
if keys.shape[2] == 1:
|
||||
return self._update_in_place(keys, values)
|
||||
return self._update_concat(keys, values)
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
if self.offset < self.keys.shape[2]:
|
||||
kv_state = (self.keys[..., : self.offset], self.values[..., : self.offset])
|
||||
else:
|
||||
kv_state = (self.keys, self.values)
|
||||
extra_state = tuple(
|
||||
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
|
||||
)
|
||||
return kv_state, extra_state
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.keys, self.values = v[0]
|
||||
self.keep, self.max_size, self.step, self.offset, self._idx = map(
|
||||
int,
|
||||
v[1],
|
||||
)
|
||||
|
||||
|
||||
class MambaCache:
|
||||
def __init__(self):
|
||||
self.cache = [None, None]
|
||||
|
||||
def __setitem__(self, idx, value):
|
||||
self.cache[idx] = value
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.cache[idx]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache, ""
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.cache = v[0]
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -69,7 +69,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -129,7 +129,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.input_layernorm(x)
|
||||
attn_h = self.self_attn(h, mask, cache)
|
||||
@@ -190,11 +190,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -49,7 +49,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
|
||||
qkv = self.Wqkv(x)
|
||||
@@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
|
||||
x = h + x
|
||||
@@ -179,7 +179,7 @@ class DecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r, h = self.norm_attn_norm(x, mask, cache)
|
||||
out = self.ffn(h) + r
|
||||
@@ -249,11 +249,3 @@ class Model(nn.Module):
|
||||
experts = [(s, sv.T) for s, sv in experts]
|
||||
new_weights.update(experts)
|
||||
return new_weights
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.d_model // self.args.n_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.attn_config["kv_n_heads"]
|
||||
|
@@ -1,10 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
@@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -210,7 +210,7 @@ class DeepseekModel(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
mask = create_attention_mask(h, cache)
|
||||
@@ -235,7 +235,7 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
@@ -256,11 +256,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -2,12 +2,12 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
|
||||
max_position_embeddings: int = 2048
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 10000.0
|
||||
rope_scaling: Optional[Dict] = None
|
||||
rope_scaling: Dict = None
|
||||
attention_bias: bool = False
|
||||
|
||||
|
||||
@@ -172,7 +172,6 @@ class DeepseekV2Attention(nn.Module):
|
||||
bias=config.attention_bias,
|
||||
)
|
||||
|
||||
if self.config.rope_scaling is not None:
|
||||
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
|
||||
scaling_factor = self.config.rope_scaling["factor"]
|
||||
if mscale_all_dim:
|
||||
@@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
mask = create_attention_mask(h, cache)
|
||||
@@ -395,7 +394,7 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
@@ -416,14 +415,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
|
||||
self.args.v_head_dim,
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -60,7 +60,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -113,7 +113,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -173,11 +173,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -64,7 +64,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
@@ -135,13 +135,11 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + self.post_attention_layernorm(r)
|
||||
r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype(
|
||||
mx.float32
|
||||
)
|
||||
r = self.mlp(self.pre_feedforward_layernorm(h))
|
||||
out = h + self.post_feedforward_layernorm(r)
|
||||
return out
|
||||
|
||||
@@ -200,11 +198,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -46,7 +46,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -196,11 +196,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.n_embd // self.args.n_head
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -57,7 +57,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -114,7 +114,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.ln_1(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -184,11 +184,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.n_embd // self.args.n_head
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -60,7 +60,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -120,7 +120,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
residual = x
|
||||
# NeoX runs attention and feedforward network in parallel.
|
||||
@@ -214,11 +214,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -116,7 +116,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -171,7 +171,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attention(self.attention_norm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -236,11 +236,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,12 +1,12 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -171,7 +171,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -233,7 +233,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -303,13 +303,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -7,6 +7,7 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .cache import MambaCache
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs):
|
||||
self.time_step_rank = math.ceil(self.hidden_size / 16)
|
||||
|
||||
|
||||
class MambaCache:
|
||||
def __init__(self):
|
||||
self.cache = [None, None]
|
||||
|
||||
def __setitem__(self, idx, value):
|
||||
self.cache[idx] = value
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.cache[idx]
|
||||
|
||||
@property
|
||||
def state(self):
|
||||
return self.cache
|
||||
|
||||
|
||||
class DepthWiseConv1d(nn.Module):
|
||||
def __init__(self, channels, kernel_size, bias=True, padding=0):
|
||||
super().__init__()
|
||||
@@ -223,7 +209,7 @@ class Model(nn.Module):
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
return weights
|
||||
|
||||
def make_cache(self, batch_size: int = 1):
|
||||
def make_cache(self):
|
||||
return [MambaCache() for _ in range(len(self.layers))]
|
||||
|
||||
@property
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -85,7 +85,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
|
||||
@@ -135,7 +135,7 @@ class DecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
|
||||
@@ -205,11 +205,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -2,7 +2,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -66,7 +66,7 @@ class MixtralAttention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -215,11 +215,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -2,12 +2,12 @@
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Dict, Optional, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -94,7 +94,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, _ = x.shape
|
||||
|
||||
@@ -151,7 +151,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -215,13 +215,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return (
|
||||
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
|
||||
)
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,8 +1,8 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from sys import exit
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -13,7 +13,7 @@ try:
|
||||
import hf_olmo
|
||||
except ImportError:
|
||||
print("To run olmo install ai2-olmo: pip install ai2-olmo")
|
||||
exit(1)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -68,7 +68,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attend(self.att_norm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -174,11 +174,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.transformer.blocks
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.d_model // self.args.n_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.n_heads
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -80,7 +80,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -152,7 +152,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.attn(self.attn_norm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -218,11 +218,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.head_dim
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_kv_heads
|
||||
|
@@ -162,19 +162,11 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
y = self.model(x, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,12 +1,12 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .su_rope import SuScaledRotaryEmbedding
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -143,7 +143,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -202,11 +202,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -3,12 +3,12 @@
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs):
|
||||
num_attention_heads: int
|
||||
layer_norm_epsilon: float
|
||||
vocab_size: int
|
||||
num_key_value_heads: Optional[int] = None
|
||||
num_key_value_heads: int
|
||||
mup_attn_multiplier: float = 1.0
|
||||
mup_use_scaling: bool = True
|
||||
mup_embedding_multiplier: float = 10.0
|
||||
mup_width_multiplier: float = 8.0
|
||||
rope_embedding_base: float = 1000000
|
||||
rope_position_scale: float = 1.0
|
||||
blocksparse_block_size: Tuple[int] = (64,)
|
||||
blocksparse_block_size: int = 64
|
||||
blocksparse_num_local_blocks: int = 16
|
||||
blocksparse_vert_stride: int = 8
|
||||
|
||||
@@ -61,7 +61,6 @@ class Attention(nn.Module):
|
||||
|
||||
dim = args.hidden_size
|
||||
self.n_heads = n_heads = args.num_attention_heads
|
||||
assert args.num_key_value_heads is not None
|
||||
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
|
||||
self.n_q_per_kv = n_heads // n_kv_heads
|
||||
|
||||
@@ -161,7 +160,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -230,7 +229,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -304,16 +303,8 @@ class Model(nn.Module):
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
return {
|
||||
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
|
||||
}
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module):
|
||||
class Model(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.model_type = args.model_type
|
||||
self.args = args
|
||||
self.model = PhiMoEModel(args)
|
||||
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
|
||||
@@ -208,11 +209,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -168,8 +168,8 @@ class Model(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
y = self.transformer(x, mask, cache)
|
||||
@@ -193,11 +193,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.model_dim // self.args.num_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_heads
|
||||
|
@@ -1,7 +1,7 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -62,8 +62,8 @@ class Attention(nn.Module):
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
bsz, q_len, _ = hidden_states.shape
|
||||
|
||||
queries = self.q_proj(hidden_states)
|
||||
@@ -127,8 +127,8 @@ class PlamoDecoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: mx.array,
|
||||
attention_mask: Optional[mx.array] = None,
|
||||
cache: Optional[Tuple[mx.array, mx.array]] = None,
|
||||
) -> Tuple[Any, ...]:
|
||||
cache: Optional[Any] = None,
|
||||
):
|
||||
# from LlamaDecoder
|
||||
residual = hidden_states
|
||||
|
||||
@@ -169,8 +169,8 @@ class PlamoModel(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
|
||||
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(inputs)
|
||||
|
||||
mask = create_attention_mask(h, cache)
|
||||
@@ -197,19 +197,11 @@ class Model(nn.Module):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: mx.array,
|
||||
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
out = self.model(inputs, cache)
|
||||
return self.lm_head(out)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_attention_heads // self.args.n_shared_head
|
||||
|
@@ -1,7 +1,6 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -149,19 +148,11 @@ class Model(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
y = self.transformer(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.transformer.h
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_attention_heads
|
||||
|
@@ -1,12 +1,12 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -70,7 +70,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -124,7 +124,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -196,11 +196,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -2,12 +2,12 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .switch_layers import SwitchGLU
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -236,11 +236,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -7,13 +7,13 @@ from typing import List, Literal, Optional
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
from .cache import MambaCache, RotatingKVCache
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArgs(BaseModelArgs):
|
||||
model_type: str
|
||||
hidden_size: int
|
||||
attention_bias: bool
|
||||
conv1d_width: int
|
||||
hidden_size: int
|
||||
@@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs):
|
||||
self.block_types = self._block_types
|
||||
|
||||
|
||||
def create_window_causal_mask(N: int, window_size: int):
|
||||
inds = mx.arange(N)
|
||||
linds = inds[:, None]
|
||||
rinds = inds[None]
|
||||
mask = (linds < rinds) | (linds > rinds + window_size)
|
||||
return mask * -1e9
|
||||
|
||||
|
||||
class RecurrentCache:
|
||||
|
||||
def __init__(self):
|
||||
self._cache = (None, None)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self._cache[idx]
|
||||
|
||||
def update(self, conv_state, recurrent_state):
|
||||
self._cache = (conv_state, recurrent_state)
|
||||
|
||||
def state(self):
|
||||
return self._cache
|
||||
|
||||
|
||||
class WindowKVCache:
|
||||
|
||||
def __init__(self, window_size):
|
||||
self.keys = None
|
||||
self.values = None
|
||||
self.offset = 0
|
||||
self.window_size = window_size
|
||||
|
||||
def update_and_fetch(self, keys, values):
|
||||
# TODO consider using rotating buffer here
|
||||
# especially for very long generations
|
||||
def _update(x, v):
|
||||
t = x.shape[2] - self.window_size
|
||||
if t > 0:
|
||||
x = x[..., t:, :]
|
||||
return mx.concatenate([x, v], axis=2)
|
||||
|
||||
self.offset += keys.shape[2]
|
||||
if self.keys is None:
|
||||
self.keys = keys
|
||||
self.values = values
|
||||
else:
|
||||
self.keys = _update(self.keys, keys)
|
||||
self.values = _update(self.values, values)
|
||||
return self.keys, self.values
|
||||
|
||||
def state(self):
|
||||
return self.keys, self.values
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dims: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
@@ -136,31 +83,22 @@ class Conv1d(nn.Module):
|
||||
kernel_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((kernel_size, channels))
|
||||
self.weight = mx.zeros((channels, kernel_size, 1))
|
||||
self.bias = mx.zeros((channels,))
|
||||
|
||||
def __call__(self, x, cache=None):
|
||||
w = self.weight.T[..., None]
|
||||
kw, groups = self.weight.shape
|
||||
if cache is not None:
|
||||
l = []
|
||||
# Pad the cache if needed
|
||||
if cache.shape[1] < kw - 1:
|
||||
l.append(
|
||||
mx.zeros(
|
||||
(x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
|
||||
)
|
||||
)
|
||||
l.extend([cache, x])
|
||||
x = mx.concatenate(l, axis=1)
|
||||
y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
|
||||
else:
|
||||
y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
|
||||
B, L, C = x.shape
|
||||
groups, K, _ = self.weight.shape
|
||||
|
||||
# The cache is always kw - 1
|
||||
cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
|
||||
if cache is not None:
|
||||
x = mx.concatenate([cache, x], axis=1)
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
|
||||
|
||||
y = mx.conv_general(x, self.weight, groups=groups)
|
||||
y = y + self.bias
|
||||
return y, cache
|
||||
|
||||
return y, x[:, -K + 1 :, :]
|
||||
|
||||
|
||||
class RGLRU(nn.Module):
|
||||
@@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module):
|
||||
# x branch.
|
||||
x = self.linear_x(x)
|
||||
if cache is None:
|
||||
conv_state, recurrent_state = (None, None)
|
||||
else:
|
||||
conv_state, recurrent_state = cache[0], cache[1]
|
||||
x, conv_state = self.conv_1d(
|
||||
x=x,
|
||||
cache=conv_state,
|
||||
)
|
||||
x, recurrent_state = self.rg_lru(
|
||||
x=x,
|
||||
cache=recurrent_state,
|
||||
)
|
||||
if cache is not None:
|
||||
cache.update(conv_state, recurrent_state)
|
||||
cache = [None, None]
|
||||
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
|
||||
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
|
||||
|
||||
x = x * y
|
||||
x = self.linear_out(x)
|
||||
@@ -467,12 +395,10 @@ class Griffin(nn.Module):
|
||||
if self.scale_by_sqrt_dim:
|
||||
x = x * math.sqrt(x.shape[-1])
|
||||
|
||||
mask = None
|
||||
if x.shape[1] > 1:
|
||||
mask = create_window_causal_mask(
|
||||
x.shape[1], self.config.attention_window_size
|
||||
)
|
||||
mask = mask.astype(x.dtype)
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
|
||||
mask = create_attention_mask(x, cache)
|
||||
|
||||
for i, block in enumerate(self.layers):
|
||||
x = block(x, mask=mask, cache=cache[i])
|
||||
@@ -485,6 +411,7 @@ class Model(nn.Module):
|
||||
def __init__(self, config):
|
||||
self.args = config
|
||||
self.model = Griffin(config)
|
||||
self.model_type = config.model_type
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
||||
|
||||
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
|
||||
@@ -508,10 +435,9 @@ class Model(nn.Module):
|
||||
return self.model.layers
|
||||
|
||||
def sanitize(self, weights):
|
||||
# Remove unused precomputed rotary freqs
|
||||
for k, v in weights.items():
|
||||
if "conv_1d.weight" in k and v.ndim == 3:
|
||||
weights[k] = v.squeeze(1).T
|
||||
weights[k] = v.moveaxis(2, 1)
|
||||
if "lm_head.weight" not in weights:
|
||||
self.pop("lm_head")
|
||||
return weights
|
||||
@@ -520,7 +446,7 @@ class Model(nn.Module):
|
||||
cache = []
|
||||
for layer in self.layers:
|
||||
if layer.temporal_block_type == "recurrent":
|
||||
cache.append(RecurrentCache())
|
||||
cache.append(MambaCache())
|
||||
else:
|
||||
cache.append(WindowKVCache(self.args.attention_window_size))
|
||||
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
|
||||
return cache
|
||||
|
@@ -2,7 +2,6 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@@ -198,8 +197,8 @@ class Model(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: mx.array = None,
|
||||
cache: mx.array = None,
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
cache=None,
|
||||
) -> mx.array:
|
||||
mask = create_attention_mask(x, cache)
|
||||
y = self.model(x, mask, cache)
|
||||
return self.lm_head(y)
|
||||
@@ -207,11 +206,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -1,12 +1,12 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
from typing import Any, Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs, KVCache, create_attention_mask
|
||||
from .base import BaseModelArgs, create_attention_mask
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -45,7 +45,7 @@ class Attention(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
B, L, D = x.shape
|
||||
|
||||
@@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
|
||||
self,
|
||||
x: mx.array,
|
||||
mask: Optional[mx.array] = None,
|
||||
cache: Optional[KVCache] = None,
|
||||
cache: Optional[Any] = None,
|
||||
) -> mx.array:
|
||||
r = self.self_attn(self.input_layernorm(x), mask, cache)
|
||||
h = x + r
|
||||
@@ -164,11 +164,3 @@ class Model(nn.Module):
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
|
||||
@property
|
||||
def head_dim(self):
|
||||
return self.args.hidden_size // self.args.num_attention_heads
|
||||
|
||||
@property
|
||||
def n_kv_heads(self):
|
||||
return self.args.num_key_value_heads
|
||||
|
@@ -18,7 +18,7 @@ from mlx.utils import tree_flatten
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
from .models.base import KVCache, RotatingKVCache
|
||||
from .models import base, cache
|
||||
from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
|
||||
from .tokenizer_utils import TokenizerWrapper, load_tokenizer
|
||||
from .tuner.utils import dequantize as dequantize_model
|
||||
@@ -124,26 +124,6 @@ def apply_repetition_penalty(logits: mx.array, tokens: mx.array, penalty: float)
|
||||
return logits
|
||||
|
||||
|
||||
def make_kv_caches(
|
||||
model: nn.Module, max_kv_size: Optional[int] = None
|
||||
) -> List[Union[KVCache, RotatingKVCache]]:
|
||||
if hasattr(model, "make_cache"):
|
||||
return model.make_cache()
|
||||
|
||||
kv_heads = (
|
||||
[model.n_kv_heads] * len(model.layers)
|
||||
if isinstance(model.n_kv_heads, int)
|
||||
else model.n_kv_heads
|
||||
)
|
||||
if max_kv_size is not None:
|
||||
return [
|
||||
RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
|
||||
for n in kv_heads
|
||||
]
|
||||
else:
|
||||
return [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
|
||||
|
||||
def generate_step(
|
||||
prompt: mx.array,
|
||||
model: nn.Module,
|
||||
@@ -155,7 +135,7 @@ def generate_step(
|
||||
min_tokens_to_keep: int = 1,
|
||||
prefill_step_size: int = 512,
|
||||
max_kv_size: Optional[int] = None,
|
||||
cache_history: Optional[List[Tuple[mx.array, mx.array]]] = None,
|
||||
prompt_cache: Optional[Any] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
logits_processor: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
|
||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||
@@ -180,6 +160,8 @@ def generate_step(
|
||||
prefill_step_size (int): Step size for processing the prompt.
|
||||
max_kv_size (int, optional): Maximum size of the key-value cache. Old
|
||||
entries (except the first 4 tokens) will be overwritten.
|
||||
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
|
||||
provided, the cache will be updated in place.
|
||||
logit_bias (dictionary, optional): Additive logit bias.
|
||||
logits_processor (List[Callable[[mx.array, mx.array], mx.array]], optional):
|
||||
A list of functions that take tokens and logits and return the processed
|
||||
@@ -237,20 +219,13 @@ def generate_step(
|
||||
tokens = None
|
||||
|
||||
# Create the KV cache for generation
|
||||
cache = make_kv_caches(model, max_kv_size)
|
||||
|
||||
if cache_history is not None:
|
||||
if len(cache_history) != len(cache):
|
||||
raise ValueError("Wrong number of layers in the cache history")
|
||||
|
||||
# Set the history in the cache objects and evaluate them to prepare for
|
||||
# generation.
|
||||
for c, h in zip(cache, cache_history):
|
||||
c.update_and_fetch(h[0], h[1])
|
||||
mx.eval([c.state for c in cache])
|
||||
if prompt_cache is None:
|
||||
prompt_cache = cache.make_prompt_cache(model, max_kv_size)
|
||||
elif len(prompt_cache) != len(model.layers):
|
||||
raise ValueError("Wrong number of layers in the prompt cache.")
|
||||
|
||||
def _step(y):
|
||||
logits = model(y[None], cache=cache)
|
||||
logits = model(y[None], cache=prompt_cache)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if logits_processor:
|
||||
@@ -265,7 +240,7 @@ def generate_step(
|
||||
|
||||
while y.size > prefill_step_size:
|
||||
model(y[:prefill_step_size][None], cache=cache)
|
||||
mx.eval([c.state for c in cache])
|
||||
mx.eval([c.state[0] for c in cache])
|
||||
y = y[prefill_step_size:]
|
||||
|
||||
y, logprobs = _step(y)
|
||||
@@ -305,9 +280,9 @@ def stream_generate(
|
||||
detokenizer = tokenizer.detokenizer
|
||||
|
||||
detokenizer.reset()
|
||||
for (token, _), n in zip(
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
for n, (token, _) in zip(
|
||||
range(max_tokens),
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
):
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
@@ -357,9 +332,9 @@ def generate(
|
||||
tic = time.perf_counter()
|
||||
detokenizer.reset()
|
||||
|
||||
for (token, logprobs), n in zip(
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
for n, (token, logprobs) in zip(
|
||||
range(max_tokens),
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
|
@@ -1,5 +1,4 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
@@ -11,7 +10,7 @@ from mlx_lm.utils import make_kv_caches
|
||||
class TestModels(unittest.TestCase):
|
||||
|
||||
def test_kv_cache(self):
|
||||
cache = KVCache(32, 4)
|
||||
cache = KVCache()
|
||||
|
||||
k = mx.ones((1, 4, 1, 32), mx.float16)
|
||||
v = mx.ones((1, 4, 1, 32), mx.float16)
|
||||
@@ -32,7 +31,7 @@ class TestModels(unittest.TestCase):
|
||||
|
||||
def test_rotating_kv_cache(self):
|
||||
b, h, d = 1, 2, 32
|
||||
cache = RotatingKVCache(d, h, max_size=8, step=4)
|
||||
cache = RotatingKVCache(max_size=8, step=4)
|
||||
|
||||
k = mx.random.uniform(shape=(b, h, 2, d))
|
||||
v = mx.random.uniform(shape=(b, h, 2, d))
|
||||
@@ -65,7 +64,7 @@ class TestModels(unittest.TestCase):
|
||||
idx %= 8
|
||||
|
||||
# Try with nonzero keep
|
||||
cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
|
||||
cache = RotatingKVCache(max_size=8, step=4, keep=2)
|
||||
|
||||
# Check a large update
|
||||
k = mx.random.uniform(shape=(b, h, 20, d))
|
||||
@@ -93,7 +92,7 @@ class TestModels(unittest.TestCase):
|
||||
# alternating prompt/prefill with generation
|
||||
d = 4
|
||||
h = 2
|
||||
cache = RotatingKVCache(d, h, max_size=18, step=4)
|
||||
cache = RotatingKVCache(max_size=18, step=4)
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 8, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
@@ -589,6 +588,179 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_deepseek(self):
|
||||
from mlx_lm.models import deepseek
|
||||
|
||||
args = deepseek.ModelArgs(
|
||||
model_type="deepseek",
|
||||
vocab_size=1024,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=4,
|
||||
)
|
||||
model = deepseek.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_deepseek_v2(self):
|
||||
from mlx_lm.models import deepseek_v2
|
||||
|
||||
args = deepseek_v2.ModelArgs(
|
||||
model_type="deepseek_v2",
|
||||
vocab_size=1024,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
kv_lora_rank=4,
|
||||
q_lora_rank=4,
|
||||
qk_rope_head_dim=32,
|
||||
v_head_dim=16,
|
||||
qk_nope_head_dim=32,
|
||||
rope_scaling={
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
"factor": 40,
|
||||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"type": "yarn",
|
||||
},
|
||||
)
|
||||
model = deepseek_v2.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gemma2(self):
|
||||
from mlx_lm.models import gemma2
|
||||
|
||||
args = gemma2.ModelArgs(
|
||||
model_type="gemma2",
|
||||
hidden_size=128,
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=2,
|
||||
head_dim=32,
|
||||
rms_norm_eps=1e-4,
|
||||
vocab_size=1024,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
model = gemma2.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gpt_bigcode(self):
|
||||
from mlx_lm.models import gpt_bigcode
|
||||
|
||||
args = gpt_bigcode.ModelArgs(
|
||||
model_type="gpt_bigcode",
|
||||
n_embd=128,
|
||||
n_layer=128,
|
||||
n_inner=256,
|
||||
n_head=4,
|
||||
n_positions=1000,
|
||||
layer_norm_epsilon=1e-5,
|
||||
vocab_size=1024,
|
||||
)
|
||||
model = gpt_bigcode.Model(args)
|
||||
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
|
||||
|
||||
def test_nemotron(self):
|
||||
from mlx_lm.models import nemotron
|
||||
|
||||
args = nemotron.ModelArgs(
|
||||
model_type="nemotron",
|
||||
hidden_size=128,
|
||||
hidden_act="gelu",
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=4,
|
||||
norm_eps=1e-5,
|
||||
vocab_size=1024,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
model = nemotron.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_phi3small(self):
|
||||
from mlx_lm.models import phi3small
|
||||
|
||||
args = phi3small.ModelArgs(
|
||||
model_type="phi3small",
|
||||
hidden_size=128,
|
||||
dense_attention_every_n_layers=2,
|
||||
ff_intermediate_size=256,
|
||||
gegelu_limit=1.0,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
layer_norm_epsilon=1e-4,
|
||||
vocab_size=1000,
|
||||
)
|
||||
model = phi3small.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_phimoe(self):
|
||||
from mlx_lm.models import phimoe
|
||||
|
||||
args = phimoe.ModelArgs(
|
||||
model_type="phimoe",
|
||||
vocab_size=320,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=4,
|
||||
rope_scaling={
|
||||
"long_factor": [1.0] * 16,
|
||||
"long_mscale": 1.243163121016122,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"short_factor": [1.0] * 16,
|
||||
"short_mscale": 1.243163121016122,
|
||||
"type": "longrope",
|
||||
},
|
||||
)
|
||||
model = phimoe.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_recurrent_gemma(self):
|
||||
from mlx_lm.models import recurrent_gemma
|
||||
|
||||
args = recurrent_gemma.ModelArgs(
|
||||
model_type="recurrent_gemma",
|
||||
hidden_size=128,
|
||||
attention_bias=False,
|
||||
conv1d_width=3,
|
||||
intermediate_size=256,
|
||||
logits_soft_cap=1.0,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=4,
|
||||
num_key_value_heads=2,
|
||||
rms_norm_eps=1e-4,
|
||||
rope_theta=1000,
|
||||
attention_window_size=1024,
|
||||
vocab_size=1000,
|
||||
block_types=["recurrent", "recurrent", "attention"],
|
||||
)
|
||||
model = recurrent_gemma.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
143
llms/tests/test_prompt_cache.py
Normal file
143
llms/tests/test_prompt_cache.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx_lm.models.cache import (
|
||||
KVCache,
|
||||
MambaCache,
|
||||
RotatingKVCache,
|
||||
load_prompt_cache,
|
||||
make_prompt_cache,
|
||||
save_prompt_cache,
|
||||
)
|
||||
from mlx_lm.utils import generate_step, load
|
||||
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
|
||||
|
||||
class TestPromptCache(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def test_save_load(self):
|
||||
cache = [KVCache() for _ in range(4)]
|
||||
for c in cache:
|
||||
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||
c.update_and_fetch(x, x)
|
||||
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||
save_prompt_cache(cache_file, cache)
|
||||
loaded_cache = load_prompt_cache(cache_file)
|
||||
self.assertTrue(len(cache), len(loaded_cache))
|
||||
for c, lc in zip(cache, loaded_cache):
|
||||
self.assertEqual(c.offset, lc.offset)
|
||||
self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0]))
|
||||
self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1]))
|
||||
|
||||
# Test with metadata
|
||||
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||
metadata = {"a": "b", "c": "d"}
|
||||
save_prompt_cache(cache_file, cache, metadata)
|
||||
_, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
|
||||
self.assertEqual(metadata, loaded_metadata)
|
||||
|
||||
def test_save_load_rotating_cache(self):
|
||||
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||
|
||||
# Test with rotating cache
|
||||
cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)]
|
||||
for c in cache:
|
||||
x = mx.random.uniform(shape=(1, 8, 10, 4))
|
||||
c.update_and_fetch(x, x)
|
||||
|
||||
save_prompt_cache(cache_file, cache)
|
||||
loaded_cache = load_prompt_cache(cache_file)
|
||||
self.assertTrue(len(cache), len(loaded_cache))
|
||||
for c, lc in zip(cache, loaded_cache):
|
||||
self.assertEqual(c.offset, lc.offset)
|
||||
self.assertEqual(c.keep, lc.keep)
|
||||
self.assertEqual(c.max_size, lc.max_size)
|
||||
self.assertEqual(c.step, lc.step)
|
||||
self.assertTrue(mx.array_equal(c.state[0][0], lc.state[0][0]))
|
||||
self.assertTrue(mx.array_equal(c.state[0][1], lc.state[0][1]))
|
||||
|
||||
# Do a couple single token updates to get a rotation
|
||||
for _ in range(2):
|
||||
for c in cache:
|
||||
x = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||
c.update_and_fetch(x, x)
|
||||
|
||||
save_prompt_cache(cache_file, cache)
|
||||
loaded_cache = load_prompt_cache(cache_file)
|
||||
|
||||
for c, lc in zip(cache, loaded_cache):
|
||||
x = mx.random.uniform(shape=(1, 8, 1, 4))
|
||||
k, v = c.update_and_fetch(x, x)
|
||||
lk, lv = lc.update_and_fetch(x, x)
|
||||
self.assertEqual(c.offset, lc.offset)
|
||||
self.assertTrue(mx.array_equal(k, lk))
|
||||
self.assertTrue(mx.array_equal(v, lv))
|
||||
|
||||
def test_save_load_mixed_cache(self):
|
||||
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
|
||||
|
||||
cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()]
|
||||
for c in cache:
|
||||
if isinstance(c, MambaCache):
|
||||
c[0] = mx.random.uniform(shape=(4, 4, 4))
|
||||
c[1] = mx.random.uniform(shape=(4, 4, 4))
|
||||
else:
|
||||
x = mx.random.uniform(shape=(4, 4, 7, 4))
|
||||
y = mx.random.uniform(shape=(4, 4, 7, 4))
|
||||
c.update_and_fetch(x, y)
|
||||
|
||||
save_prompt_cache(cache_file, cache)
|
||||
loaded_cache = load_prompt_cache(cache_file)
|
||||
for c, lc in zip(cache, loaded_cache):
|
||||
if isinstance(c, MambaCache):
|
||||
self.assertTrue(mx.array_equal(c[0], lc[0]))
|
||||
self.assertTrue(mx.array_equal(c[1], lc[1]))
|
||||
else:
|
||||
x = mx.random.uniform(shape=(4, 4, 1, 4))
|
||||
y = mx.random.uniform(shape=(4, 4, 1, 4))
|
||||
k, v = c.update_and_fetch(x, y)
|
||||
lk, lv = lc.update_and_fetch(x, y)
|
||||
self.assertEqual(c.offset, lc.offset)
|
||||
self.assertTrue(mx.array_equal(k, lk))
|
||||
self.assertTrue(mx.array_equal(v, lv))
|
||||
|
||||
def test_cache_with_generate(self):
|
||||
model, tokenizer = load(HF_MODEL_PATH)
|
||||
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
|
||||
results = zip(range(4), generate_step(prompt, model))
|
||||
toks, all_logits = zip(*(r[1] for r in results))
|
||||
|
||||
prompt_cache = make_prompt_cache(model)
|
||||
i = 0
|
||||
for _, (tok, logits) in zip(
|
||||
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
|
||||
):
|
||||
self.assertEqual(tok, toks[i])
|
||||
self.assertTrue(mx.allclose(logits, all_logits[i]))
|
||||
i += 1
|
||||
|
||||
for _, (tok, logits) in zip(
|
||||
range(1),
|
||||
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
|
||||
):
|
||||
i += 1
|
||||
self.assertEqual(tok, toks[i])
|
||||
self.assertTrue(mx.allclose(logits, all_logits[i]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user