More cache improvements (#1015)

* fix rotating kv cache for chat use case

* reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat

* nit in chat

* fix tests

* fix tests

* fix tests

* docs

* chat command

* comments + docs

* Define meta_state on all Cache implementations

* fixes + trim_prompt_cache api

* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
Awni Hannun
2024-10-07 20:45:51 -07:00
committed by GitHub
parent 9bc53fc210
commit fca087be49
43 changed files with 1151 additions and 691 deletions

View File

@@ -2,145 +2,9 @@
import inspect
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
class KVCache:
def __init__(self, head_dim, n_kv_heads):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B = keys.shape[0]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
return self.keys, self.values
class RotatingKVCache:
def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def update_and_fetch(self, keys, values):
prev = self.offset
B, _, S = keys.shape[:3]
# Prefill mode
if S > 1:
if self.keys is None:
self.keys = keys
self.values = values
else:
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self.keys.shape[2] - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += S
self._idx = self.keys.shape[2]
return self.keys, self.values
# Generation mode
# May not have hit the max size yet, so potentially
# keep growing the cache
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
new_size = min(self.step, self.max_size - prev)
k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + 1, :] = keys
self.values[..., self._idx : self._idx + 1, :] = values
self.offset += 1
self._idx += 1
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
@property
def state(self):
return self.keys, self.values
@dataclass
@@ -156,25 +20,30 @@ class BaseModelArgs:
)
def create_additive_causal_mask(N: int, offset: int = 0):
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
linds = linds[:, None]
rinds = rinds[None]
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if isinstance(c, RotatingKVCache):
if hasattr(c, "max_size"):
offset = min(c.max_size - 1, c.offset)
window_size = c.max_size
else:
offset = c.offset
else:
offset = 0
mask = create_additive_causal_mask(T, offset)
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None

333
llms/mlx_lm/models/cache.py Normal file
View File

@@ -0,0 +1,333 @@
# Copyright © 2023-2024 Apple Inc.
from typing import Any, Dict, List, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
"""
Construct the model's cache for use when cgeneration.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``
"""
if hasattr(model, "make_cache"):
return model.make_cache()
num_layers = len(model.layers)
if max_kv_size is not None:
return [
RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
]
else:
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
"""
Save a pre-computed prompt cache to a file.
Args:
file_name (str): The ``.safetensors`` file name.
cache (List[Any]): The model state.
metadata (Dict[str, str]): Optional metadata to save along with model
state.
"""
cache_data = [c.state for c in cache]
cache_info = [c.meta_state for c in cache]
cache_data = dict(tree_flatten(cache_data))
cache_classes = [type(c).__name__ for c in cache]
cache_metadata = [cache_info, metadata, cache_classes]
cache_metadata = dict(tree_flatten(cache_metadata))
mx.save_safetensors(file_name, cache_data, cache_metadata)
def load_prompt_cache(file_name, return_metadata=False):
"""
Load a prompt cache from a file.
Args:
file_name (str): The ``.safetensors`` file name.
return_metadata (bool): Whether or not to return metadata.
Default: ``False``.
Returns:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
"""
arrays, cache_metadata = mx.load(file_name, return_metadata=True)
arrays = tree_unflatten(list(arrays.items()))
cache_metadata = tree_unflatten(list(cache_metadata.items()))
info, metadata, classes = cache_metadata
cache = [globals()[c]() for c in classes]
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
c.meta_state = meta_state
if return_metadata:
return cache, metadata
return cache
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
This function will trim the cache if possible (in-place) and return the
number of tokens that were trimmed.
Args:
cache (List[Any]): The model's cache.
num_tokens (int): The number of tokens to trim.
Returns:
(int): The number of tokens that were trimmed.
"""
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]
class _BaseCache:
@property
def state(self):
return []
@state.setter
def state(self, v):
if v is not None and v:
raise ValueError("This cache has no state but a state was set.")
@property
def meta_state(self):
return ""
@meta_state.setter
def meta_state(self, v):
if v is not None and v:
raise ValueError("This cache has no meta_state but a meta_state was set.")
def is_trimmable(self):
return False
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B, n_kv_heads, _, k_head_dim = keys.shape
v_head_dim = values.shape[3]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
if self.offset == self.keys.shape[2]:
return self.keys, self.values
else:
return (
self.keys[..., : self.offset, :],
self.values[..., : self.offset, :],
)
@state.setter
def state(self, v):
self.keys, self.values = v
self.offset = self.keys.shape[2]
def is_trimmable(self):
return True
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
return n
class RotatingKVCache(_BaseCache):
def __init__(self, max_size=None, keep=0, step=256):
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def _temporal_order(self, v):
"""
Rearrange the cache into temporal order, slicing off the end if unused.
"""
if self._idx == v.shape[2]:
return v
elif self._idx < self.offset:
return mx.concatenate(
[
v[..., : self.keep, :],
v[..., self._idx :, :],
v[..., self.keep : self._idx, :],
],
axis=2,
)
else:
return v[..., : self._idx, :]
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
# Put the keys/values in temporal order to
# preserve context
self.keys = self._temporal_order(self.keys)
self.values = self._temporal_order(self.values)
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, n_kv_heads, S, k_head_dim = keys.shape
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
v_head_dim = values.shape[3]
new_size = min(self.step, self.max_size - prev)
k_shape = (B, n_kv_heads, new_size, k_head_dim)
v_shape = (B, n_kv_heads, new_size, v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
if keys.shape[2] == 1:
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
if self.offset < self.keys.shape[2]:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
else:
return self.keys, self.values
@state.setter
def state(self, v):
self.keys, self.values = v
@property
def meta_state(self):
return tuple(
map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
)
@meta_state.setter
def meta_state(self, v):
self.keep, self.max_size, self.step, self.offset, self._idx = map(
int,
v,
)
def is_trimmable(self):
return self.offset < self.max_size
def trim(self, n):
n = min(self.offset, n)
self.offset -= n
self._idx -= n
return n
class MambaCache(_BaseCache):
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
@state.setter
def state(self, v):
self.cache = v

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -69,7 +69,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -129,7 +129,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.input_layernorm(x)
attn_h = self.self_attn(h, mask, cache)
@@ -190,11 +190,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -49,7 +49,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
qkv = self.Wqkv(x)
@@ -92,7 +92,7 @@ class NormAttnNorm(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.attn(self.norm_1(x), mask=mask, cache=cache)
x = h + x
@@ -179,7 +179,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r, h = self.norm_attn_norm(x, mask, cache)
out = self.ffn(h) + r
@@ -249,11 +249,3 @@ class Model(nn.Module):
experts = [(s, sv.T) for s, sv in experts]
new_weights.update(experts)
return new_weights
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.attn_config["kv_n_heads"]

View File

@@ -1,10 +1,10 @@
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Any, Dict, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -77,7 +77,7 @@ class DeepseekAttention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
@@ -108,8 +108,8 @@ class DeepseekMLP(nn.Module):
def __init__(
self,
config: ModelArgs,
hidden_size: int | None = None,
intermediate_size: int | None = None,
hidden_size: Optional[int] = None,
intermediate_size: Optional[int] = None,
):
super().__init__()
self.config = config
@@ -188,7 +188,7 @@ class DeepseekDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -210,7 +210,7 @@ class DeepseekModel(nn.Module):
def __call__(
self,
x: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
@@ -235,7 +235,7 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@@ -256,11 +256,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -2,12 +2,12 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -38,7 +38,7 @@ class ModelArgs(BaseModelArgs):
max_position_embeddings: int = 2048
rms_norm_eps: float = 1e-6
rope_theta: float = 10000.0
rope_scaling: Optional[Dict] = None
rope_scaling: Dict = None
attention_bias: bool = False
@@ -172,12 +172,11 @@ class DeepseekV2Attention(nn.Module):
bias=config.attention_bias,
)
if self.config.rope_scaling is not None:
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0)
scaling_factor = self.config.rope_scaling["factor"]
if mscale_all_dim:
mscale = yarn_get_mscale(scaling_factor, mscale_all_dim)
self.scale = self.scale * mscale * mscale
rope_kwargs = {
key: self.config.rope_scaling[key]
@@ -202,7 +201,7 @@ class DeepseekV2Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -347,7 +346,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -370,7 +369,7 @@ class DeepseekV2Model(nn.Module):
def __call__(
self,
x: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(x)
mask = create_attention_mask(h, cache)
@@ -395,7 +394,7 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
):
out = self.model(inputs, cache)
return self.lm_head(out)
@@ -416,14 +415,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.qk_nope_head_dim + self.args.qk_rope_head_dim,
self.args.v_head_dim,
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -60,7 +60,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -113,7 +113,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -173,11 +173,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -64,7 +64,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
@@ -135,13 +135,11 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x.astype(mx.float32)), mask, cache)
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + self.post_attention_layernorm(r)
r = self.mlp(self.pre_feedforward_layernorm(h).astype(mx.float16)).astype(
mx.float32
)
r = self.mlp(self.pre_feedforward_layernorm(h))
out = h + self.post_feedforward_layernorm(r)
return out
@@ -200,11 +198,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -46,7 +46,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@@ -196,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.h
@property
def head_dim(self):
return self.args.n_embd // self.args.n_head
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -57,7 +57,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -114,7 +114,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.ln_1(x), mask, cache)
h = x + r
@@ -184,11 +184,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.n_embd // self.args.n_head
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -60,7 +60,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -120,7 +120,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
residual = x
# NeoX runs attention and feedforward network in parallel.
@@ -214,11 +214,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.h
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -116,7 +116,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -171,7 +171,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attention(self.attention_norm(x), mask, cache)
h = x + r
@@ -236,11 +236,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -171,7 +171,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -233,7 +233,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -303,13 +303,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -7,6 +7,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
@@ -45,21 +46,6 @@ class ModelArgs(BaseModelArgs):
self.time_step_rank = math.ceil(self.hidden_size / 16)
class MambaCache:
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
@@ -223,7 +209,7 @@ class Model(nn.Module):
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self, batch_size: int = 1):
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -85,7 +85,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
):
B, L, _ = x.shape
@@ -135,7 +135,7 @@ class DecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r * (self.scale_depth / np.sqrt(self.num_hidden_layers))
@@ -205,11 +205,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -2,7 +2,7 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -66,7 +66,7 @@ class MixtralAttention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -138,7 +138,7 @@ class MixtralDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -215,11 +215,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -2,12 +2,12 @@
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -94,7 +94,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, _ = x.shape
@@ -151,7 +151,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -215,13 +215,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return (
self.args.head_dim or self.args.hidden_size // self.args.num_attention_heads
)
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,8 +1,8 @@
# Copyright © 2023-2024 Apple Inc.
import sys
from dataclasses import dataclass
from sys import exit
from typing import Optional, Tuple
from typing import Any, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -13,7 +13,7 @@ try:
import hf_olmo
except ImportError:
print("To run olmo install ai2-olmo: pip install ai2-olmo")
exit(1)
sys.exit(1)
@dataclass
@@ -68,7 +68,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -98,7 +98,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attend(self.att_norm(x), mask, cache)
h = x + r
@@ -174,11 +174,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.transformer.blocks
@property
def head_dim(self):
return self.args.d_model // self.args.n_heads
@property
def n_kv_heads(self):
return self.args.n_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -80,7 +80,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -152,7 +152,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.attn(self.attn_norm(x), mask, cache)
h = x + r
@@ -218,11 +218,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.layers
@property
def head_dim(self):
return self.args.head_dim
@property
def n_kv_heads(self):
return self.args.num_kv_heads

View File

@@ -162,19 +162,11 @@ class Model(nn.Module):
def __call__(
self,
x: mx.array,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
y = self.model(x, cache)
return self.lm_head(y)
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .su_rope import SuScaledRotaryEmbedding
@@ -84,7 +84,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -143,7 +143,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -202,11 +202,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -3,12 +3,12 @@
import math
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -22,14 +22,14 @@ class ModelArgs(BaseModelArgs):
num_attention_heads: int
layer_norm_epsilon: float
vocab_size: int
num_key_value_heads: Optional[int] = None
num_key_value_heads: int
mup_attn_multiplier: float = 1.0
mup_use_scaling: bool = True
mup_embedding_multiplier: float = 10.0
mup_width_multiplier: float = 8.0
rope_embedding_base: float = 1000000
rope_position_scale: float = 1.0
blocksparse_block_size: Tuple[int] = (64,)
blocksparse_block_size: int = 64
blocksparse_num_local_blocks: int = 16
blocksparse_vert_stride: int = 8
@@ -61,7 +61,6 @@ class Attention(nn.Module):
dim = args.hidden_size
self.n_heads = n_heads = args.num_attention_heads
assert args.num_key_value_heads is not None
self.n_kv_heads = n_kv_heads = args.num_key_value_heads
self.n_q_per_kv = n_heads // n_kv_heads
@@ -161,7 +160,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -230,7 +229,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -304,16 +303,8 @@ class Model(nn.Module):
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
return {
k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k
}
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -173,6 +173,7 @@ class PhiMoEModel(nn.Module):
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.model_type = args.model_type
self.args = args
self.model = PhiMoEModel(args)
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=True)
@@ -208,11 +209,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -168,8 +168,8 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.transformer(x, mask, cache)
@@ -193,11 +193,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.model_dim // self.args.num_heads
@property
def n_kv_heads(self):
return self.args.num_heads

View File

@@ -1,7 +1,7 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple, Union
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
@@ -62,8 +62,8 @@ class Attention(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
cache: Optional[Any] = None,
) -> mx.array:
bsz, q_len, _ = hidden_states.shape
queries = self.q_proj(hidden_states)
@@ -127,8 +127,8 @@ class PlamoDecoderLayer(nn.Module):
self,
hidden_states: mx.array,
attention_mask: Optional[mx.array] = None,
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[Any, ...]:
cache: Optional[Any] = None,
):
# from LlamaDecoder
residual = hidden_states
@@ -169,8 +169,8 @@ class PlamoModel(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Union[Tuple[mx.array, mx.array], None]]] = None,
) -> Tuple[mx.array, Optional[List[Union[Tuple[mx.array, mx.array], None]]]]:
cache: Optional[Any] = None,
) -> mx.array:
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
@@ -197,19 +197,11 @@ class Model(nn.Module):
def __call__(
self,
inputs: mx.array,
cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
) -> Tuple[mx.array, mx.array]:
cache: Optional[Any] = None,
) -> mx.array:
out = self.model(inputs, cache)
return self.lm_head(out)
@property
def layers(self):
return self.model.layers.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads // self.args.n_shared_head

View File

@@ -1,7 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -149,19 +148,11 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
y = self.transformer(x, mask, cache)
return self.lm_head(y)
@property
def layers(self):
return self.transformer.h
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_attention_heads

View File

@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -70,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -124,7 +124,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -196,11 +196,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -2,12 +2,12 @@
import math
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
from .switch_layers import SwitchGLU
@@ -70,7 +70,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -162,7 +162,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -236,11 +236,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -7,13 +7,13 @@ from typing import List, Literal, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
from .cache import MambaCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
attention_bias: bool
conv1d_width: int
hidden_size: int
@@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs):
self.block_types = self._block_types
def create_window_causal_mask(N: int, window_size: int):
inds = mx.arange(N)
linds = inds[:, None]
rinds = inds[None]
mask = (linds < rinds) | (linds > rinds + window_size)
return mask * -1e9
class RecurrentCache:
def __init__(self):
self._cache = (None, None)
def __getitem__(self, idx):
return self._cache[idx]
def update(self, conv_state, recurrent_state):
self._cache = (conv_state, recurrent_state)
def state(self):
return self._cache
class WindowKVCache:
def __init__(self, window_size):
self.keys = None
self.values = None
self.offset = 0
self.window_size = window_size
def update_and_fetch(self, keys, values):
# TODO consider using rotating buffer here
# especially for very long generations
def _update(x, v):
t = x.shape[2] - self.window_size
if t > 0:
x = x[..., t:, :]
return mx.concatenate([x, v], axis=2)
self.offset += keys.shape[2]
if self.keys is None:
self.keys = keys
self.values = values
else:
self.keys = _update(self.keys, keys)
self.values = _update(self.values, values)
return self.keys, self.values
def state(self):
return self.keys, self.values
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
@@ -136,31 +83,22 @@ class Conv1d(nn.Module):
kernel_size: int,
):
super().__init__()
self.weight = mx.zeros((kernel_size, channels))
self.weight = mx.zeros((channels, kernel_size, 1))
self.bias = mx.zeros((channels,))
def __call__(self, x, cache=None):
w = self.weight.T[..., None]
kw, groups = self.weight.shape
if cache is not None:
l = []
# Pad the cache if needed
if cache.shape[1] < kw - 1:
l.append(
mx.zeros(
(x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
)
)
l.extend([cache, x])
x = mx.concatenate(l, axis=1)
y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
else:
y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
B, L, C = x.shape
groups, K, _ = self.weight.shape
# The cache is always kw - 1
cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
y = y + self.bias
return y, cache
return y, x[:, -K + 1 :, :]
class RGLRU(nn.Module):
@@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module):
# x branch.
x = self.linear_x(x)
if cache is None:
conv_state, recurrent_state = (None, None)
else:
conv_state, recurrent_state = cache[0], cache[1]
x, conv_state = self.conv_1d(
x=x,
cache=conv_state,
)
x, recurrent_state = self.rg_lru(
x=x,
cache=recurrent_state,
)
if cache is not None:
cache.update(conv_state, recurrent_state)
cache = [None, None]
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
x = x * y
x = self.linear_out(x)
@@ -467,12 +395,14 @@ class Griffin(nn.Module):
if self.scale_by_sqrt_dim:
x = x * math.sqrt(x.shape[-1])
mask = None
if x.shape[1] > 1:
mask = create_window_causal_mask(
x.shape[1], self.config.attention_window_size
)
mask = mask.astype(x.dtype)
if cache is None:
cache = [None] * len(self.layers)
for i, block in enumerate(self.layers):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
@@ -485,6 +415,7 @@ class Model(nn.Module):
def __init__(self, config):
self.args = config
self.model = Griffin(config)
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
@@ -508,10 +439,9 @@ class Model(nn.Module):
return self.model.layers
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
for k, v in weights.items():
if "conv_1d.weight" in k and v.ndim == 3:
weights[k] = v.squeeze(1).T
weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")
return weights
@@ -520,7 +450,7 @@ class Model(nn.Module):
cache = []
for layer in self.layers:
if layer.temporal_block_type == "recurrent":
cache.append(RecurrentCache())
cache.append(MambaCache())
else:
cache.append(WindowKVCache(self.args.attention_window_size))
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
return cache

View File

@@ -2,7 +2,6 @@
import math
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
@@ -198,8 +197,8 @@ class Model(nn.Module):
self,
x: mx.array,
mask: mx.array = None,
cache: mx.array = None,
) -> Tuple[mx.array, mx.array]:
cache=None,
) -> mx.array:
mask = create_attention_mask(x, cache)
y = self.model(x, mask, cache)
return self.lm_head(y)
@@ -207,11 +206,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads

View File

@@ -1,12 +1,12 @@
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, KVCache, create_attention_mask
from .base import BaseModelArgs, create_attention_mask
@dataclass
@@ -45,7 +45,7 @@ class Attention(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
B, L, D = x.shape
@@ -100,7 +100,7 @@ class TransformerBlock(nn.Module):
self,
x: mx.array,
mask: Optional[mx.array] = None,
cache: Optional[KVCache] = None,
cache: Optional[Any] = None,
) -> mx.array:
r = self.self_attn(self.input_layernorm(x), mask, cache)
h = x + r
@@ -164,11 +164,3 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
@property
def head_dim(self):
return self.args.hidden_size // self.args.num_attention_heads
@property
def n_kv_heads(self):
return self.args.num_key_value_heads