# Copyright © 2023-2024 Apple Inc.

import math
from dataclasses import dataclass
from typing import List, Literal, Optional

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs, create_attention_mask
from .cache import MambaCache, RotatingKVCache


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    attention_bias: bool
    conv1d_width: int
    hidden_size: int
    intermediate_size: int
    logits_soft_cap: float
    num_attention_heads: int
    num_hidden_layers: int
    num_key_value_heads: int
    rms_norm_eps: float
    rope_theta: float
    attention_window_size: int
    vocab_size: int
    embeddings_scale_by_sqrt_dim: bool = True
    block_types: Optional[List[str]] = None
    _block_types: Optional[List[str]] = None

    def __post_init__(self):
        # For some reason these have different names in 2B and 9B
        if self.block_types is None:
            self.block_types = self._block_types


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)


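# Note: `rnn_scan` below evaluates the linear recurrence
#
#     h_t = a_t * h_{t-1} + x_t
#
# over the time axis and returns both the per-step outputs and the final
# hidden state. The single-token branch is the fast path for autoregressive
# sampling, where the previous state `h0` comes from the layer's cache.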
def rnn_scan(x, a, h0):
    assert x.ndim == 3
    assert a.shape == x.shape[-a.ndim :]
    assert a.dtype == x.dtype

    if x.shape[1] == 1:
        # Using scan in sampling mode.
        if h0 is None:
            return x, x[:, 0]

        else:
            y = a * h0[:, None] + x
            return y, y[:, -1]

    else:
        # Using scan in linear mode.
        if h0 is not None:
            h_t = h0
        else:
            B, _, D = x.shape
            h_t = mx.zeros((B, D), dtype=x.dtype)

        y = mx.zeros_like(x)
        for t in range(x.shape[1]):
            h_t = a[:, t] * h_t + x[:, t]
            y[:, t] = h_t

        return y, h_t


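# Note: `Conv1d` below is a depthwise causal convolution: the weight has shape
# (channels, kernel_size, 1) and `groups == channels`, so every channel is
# filtered independently. Without a cache the input is left-padded with
# `kernel_size - 1` zeros; with a cache the previous `kernel_size - 1` time
# steps are prepended instead, and the trailing `kernel_size - 1` inputs are
# returned as the new cache state.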
class Conv1d(nn.Module):
    def __init__(
        self,
        channels: int,
        kernel_size: int,
    ):
        super().__init__()
        self.weight = mx.zeros((channels, kernel_size, 1))
        self.bias = mx.zeros((channels,))

    def __call__(self, x, cache=None):
        B, L, C = x.shape
        groups, K, _ = self.weight.shape

        if cache is not None:
            x = mx.concatenate([cache, x], axis=1)
        else:
            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])

        y = mx.conv_general(x, self.weight, groups=groups)
        y = y + self.bias

        return y, x[:, -K + 1 :, :]


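# Note: the RG-LRU below computes, per time step and channel,
#
#     i_t = sigmoid(W_x x_t + b_x)                  # input gate
#     r_t = sigmoid(W_a x_t + b_a)                  # recurrence gate
#     a_t = exp(-8 * r_t * softplus(recurrent_param))
#     h_t = a_t * h_{t-1} + sqrt(1 - a_t^2) * (i_t * x_t)
#
# where the gate projections are block-diagonal across heads
# (`apply_block_linear`). When there is no cached state, the sqrt(1 - a_t^2)
# factor is replaced by 1 at the first position.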
class RGLRU(nn.Module):
    """A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""

    def __init__(
        self,
        width: int,
        num_heads: int,
    ):
        super().__init__()
        self.width = width
        self.num_heads = num_heads
        self.head_dim = self.width // self.num_heads

        self.recurrent_param = mx.zeros((self.width,))

        self.input_gate_weight = mx.zeros(
            (self.num_heads, self.head_dim, self.head_dim),
        )
        self.input_gate_bias = mx.zeros((self.num_heads, self.head_dim))

        self.recurrent_gate_weight = mx.zeros(
            (self.num_heads, self.head_dim, self.head_dim),
        )
        self.recurrent_gate_bias = mx.zeros((self.num_heads, self.head_dim))

    def __call__(
        self,
        x: mx.array,
        cache=None,
    ):
        B, L, _ = x.shape

        def apply_block_linear(h, w, b):
            h = h.reshape((B, L, self.num_heads, self.head_dim))
            h = (h.swapaxes(1, 2) @ w).swapaxes(1, 2) + b
            return mx.sigmoid(h.flatten(2, 3))

        # Gates for x and a.
        gate_x = apply_block_linear(x, self.input_gate_weight, self.input_gate_bias)
        gate_a = apply_block_linear(
            x, self.recurrent_gate_weight, self.recurrent_gate_bias
        )

        # Compute the parameter `A` of the recurrence.
        log_a = -8.0 * gate_a * nn.softplus(self.recurrent_param)
        a = mx.exp(log_a)
        a_square = mx.exp(2 * log_a)

        # Gate the input.
        gated_x = x * gate_x

        # Apply gamma normalization to the input.
        multiplier = mx.sqrt(1 - a_square)
        if cache is None:
            multiplier[:, 0, :] = 1.0
        normalized_x = gated_x * multiplier.astype(x.dtype)

        y, last_h = rnn_scan(
            x=normalized_x,
            a=a,
            h0=cache,
        )

        return y, last_h


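# Note: `RecurrentBlock` below mixes the sequence with two branches: the input
# is projected to `lru_width` twice, one branch goes through a GeLU while the
# other goes through the causal `Conv1d` and the RG-LRU, and the two are
# multiplied before the output projection. Its cache is the pair
# `[conv_state, rg_lru_state]`.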
class RecurrentBlock(nn.Module):

    def __init__(
        self,
        width: int,
        num_heads: int,
        lru_width: int = None,
        conv1d_temporal_width: int = 4,
    ):
        super().__init__()
        self.width = width
        self.num_heads = num_heads
        self.lru_width = lru_width or width
        self.conv1d_temporal_width = conv1d_temporal_width

        self.linear_y = nn.Linear(width, self.lru_width)
        self.linear_x = nn.Linear(width, self.lru_width)
        self.linear_out = nn.Linear(self.lru_width, width)
        self.conv_1d = Conv1d(
            channels=self.lru_width,
            kernel_size=self.conv1d_temporal_width,
        )
        self.rg_lru = RGLRU(
            width=self.lru_width,
            num_heads=self.num_heads,
        )

    def __call__(
        self,
        x: mx.array,
        cache=None,
        mask=None,
    ):
        # y branch.
        y = self.linear_y(x)
        y = nn.gelu_approx(y)

        # x branch.
        x = self.linear_x(x)
        if cache is None:
            cache = [None, None]
        x, cache[0] = self.conv_1d(x=x, cache=cache[0])
        x, cache[1] = self.rg_lru(x=x, cache=cache[1])

        x = x * y
        x = self.linear_out(x)

        return x


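# Note: `LocalAttentionBlock` below uses multi-query attention: `q_proj`
# produces `num_heads` query heads while `k_proj` and `v_proj` each produce a
# single shared key/value head. RoPE is applied to half of each head dimension
# (`head_dim // 2`). The sliding window of `window_size` tokens is enforced at
# generation time by the `RotatingKVCache` created in `Model.make_cache`.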
class LocalAttentionBlock(nn.Module):

    def __init__(
        self,
        width: int,
        num_heads: int,
        window_size: int,
    ):
        super().__init__()
        self.width = width
        self.num_heads = num_heads
        self.window_size = window_size
        self.scale = (width // num_heads) ** (-0.5)

        self.head_dim = self.width // self.num_heads
        self.q_proj = nn.Linear(self.width, self.width, bias=False)
        self.k_proj = nn.Linear(self.width, self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.width, self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.width, self.width, bias=True)
        self.rope = nn.RoPE(
            self.head_dim // 2,
            traditional=False,
        )

    def __call__(
        self,
        x: mx.array,
        cache=None,
        mask=None,
    ):
        B, L, D = x.shape

        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)

        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)

        if cache is not None:
            queries = self.rope(queries, offset=cache.offset)
            keys = self.rope(keys, offset=cache.offset)
            keys, values = cache.update_and_fetch(keys, values)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        output = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.o_proj(output)


class MLPBlock(nn.Module):

    def __init__(self, width: int, expanded_width: int):
        super().__init__()
        self.up_proj = nn.Linear(width, expanded_width // 2)
        self.gate_proj = nn.Linear(width, expanded_width // 2)
        self.down_proj = nn.Linear(expanded_width // 2, width)

    def __call__(self, x: mx.array):
        gate = self.gate_proj(x)
        x = self.up_proj(x)
        return self.down_proj(nn.gelu_approx(gate) * x)


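# Note: `ResidualBlock` below is a pre-norm block: RMSNorm, then the temporal
# mixer (recurrent or attention, chosen by `temporal_block_type`) with a skip
# connection, then RMSNorm and the gated MLP with a second skip connection.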
class ResidualBlock(nn.Module):

    def __init__(
        self,
        width: int,
        mlp_expanded_width: int,
        num_heads: int,
        attention_window_size: int,
        temporal_block_type: str,
        lru_width: Optional[int] = None,
        conv1d_temporal_width: int = 4,
    ):
        """Initializes the residual block.

        Args:
          width: The width of the block.
          mlp_expanded_width: The width of the expansion inside the MLP block.
          num_heads: The number of heads for the Attention or the RG-LRU.
          attention_window_size: The window size for the local attention block.
          temporal_block_type: Either "recurrent" or "attention", specifying the
            type of temporal block to use.
          lru_width: The width of the RG-LRU if different from `width`.
          conv1d_temporal_width: The width of the temporal convolution.
        """
        super().__init__()
        self.width = width
        self.mlp_expanded_width = mlp_expanded_width
        self.num_heads = num_heads
        self.attention_window_size = attention_window_size
        self.temporal_block_type = temporal_block_type
        self.lru_width = lru_width
        self.conv1d_temporal_width = conv1d_temporal_width

        self.temporal_pre_norm = RMSNorm(width)
        if self.temporal_block_type == "recurrent":
            self.temporal_block = RecurrentBlock(
                width=self.width,
                num_heads=self.num_heads,
                lru_width=self.lru_width,
                conv1d_temporal_width=self.conv1d_temporal_width,
            )

        else:
            self.temporal_block = LocalAttentionBlock(
                width=self.width,
                num_heads=self.num_heads,
                window_size=self.attention_window_size,
            )

        self.channel_pre_norm = RMSNorm(width)
        self.mlp_block = MLPBlock(
            width=self.width,
            expanded_width=self.mlp_expanded_width,
        )

    def __call__(
        self,
        x: mx.array,
        cache=None,
        mask=None,
    ):
        raw_x = x

        inputs_normalized = self.temporal_pre_norm(raw_x)

        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
        residual = x + raw_x

        x = self.channel_pre_norm(residual)
        x = self.mlp_block(x)

        x = x + residual

        return x


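# Note: `Griffin` below assigns layer i the temporal block type
# `block_types[i % len(block_types)]`, so the configured pattern of recurrent
# and attention layers repeats over the depth of the model. A single attention
# mask is built from the cache of one attention layer and shared by all
# attention layers in the forward pass.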
class Griffin(nn.Module):
    def __init__(self, config):
        super().__init__()

        self.config = config
        self.embed_tokens = nn.Embedding(
            config.vocab_size,
            config.hidden_size,
        )

        self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
        block_types = config.block_types

        self.layers = [
            ResidualBlock(
                width=config.hidden_size,
                mlp_expanded_width=config.intermediate_size,
                num_heads=config.num_attention_heads,
                attention_window_size=config.attention_window_size,
                temporal_block_type=block_types[i % len(block_types)],
                lru_width=None,
            )
            for i in range(config.num_hidden_layers)
        ]
        self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def __call__(
        self,
        tokens,
        cache=None,
    ):
        x = self.embed_tokens(tokens)
        if self.scale_by_sqrt_dim:
            x = x * math.sqrt(x.shape[-1])

        if cache is None:
            cache = [None] * len(self.layers)

        for i, block in enumerate(self.layers):
            if block.temporal_block_type != "recurrent":
                mask_cache = [cache[i]]

        mask = create_attention_mask(x, mask_cache)

        for i, block in enumerate(self.layers):
            x = block(x, mask=mask, cache=cache[i])

        return self.final_norm(x)


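# Note: in `Model` below, `sanitize` drops `lm_head` when the checkpoint has no
# `lm_head.weight`, in which case the embedding matrix is reused (tied weights)
# via `embed_tokens.as_linear`. Logits are soft-capped with
# `logits_soft_cap * tanh(logits / logits_soft_cap)` when the cap is non-zero.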
class Model(nn.Module):

    def __init__(self, config):
        self.args = config
        self.model = Griffin(config)
        self.model_type = config.model_type
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def __call__(self, tokens: mx.array, cache=None) -> mx.array:
        """
        Args:
            tokens: Sequence of input tokens.
        """
        logits = self.model(tokens, cache=cache)
        if "lm_head" in self:
            logits = self.lm_head(logits)
        else:
            logits = self.model.embed_tokens.as_linear(logits)

        c = self.args.logits_soft_cap
        if c:
            logits = mx.tanh(logits / c) * c
        return logits

    @property
    def layers(self):
        return self.model.layers

    def sanitize(self, weights):
        for k, v in weights.items():
            if "conv_1d.weight" in k and v.ndim == 3:
                weights[k] = v.moveaxis(2, 1)
        if "lm_head.weight" not in weights:
            self.pop("lm_head")
        return weights

    def make_cache(self):
        cache = []
        for layer in self.layers:
            if layer.temporal_block_type == "recurrent":
                cache.append(MambaCache())
            else:
                cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
        return cache


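# A minimal construction sketch (the hyperparameter values are illustrative
# assumptions, not taken from any released RecurrentGemma config):
#
#     args = ModelArgs(
#         model_type="recurrentgemma",
#         attention_bias=False,
#         conv1d_width=4,
#         hidden_size=256,
#         intermediate_size=1536,
#         logits_soft_cap=30.0,
#         num_attention_heads=8,
#         num_hidden_layers=6,
#         num_key_value_heads=1,
#         rms_norm_eps=1e-6,
#         rope_theta=10000.0,
#         attention_window_size=128,
#         vocab_size=1000,
#         block_types=["recurrent", "recurrent", "attention"],
#     )
#     model = Model(args)
#     cache = model.make_cache()
#     logits = model(mx.array([[1, 2, 3]]), cache=cache)  # shape (1, 3, vocab_size)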