From 20e221f7f7e3e7e3b21205af39e6b8bf5cf57a40 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 7 Jul 2024 12:10:04 -0700
Subject: [PATCH] Add recurrent gemma (#856)

* add recurrent gemma

* fix window cache
---
 llms/mlx_lm/models/recurrent_gemma.py | 505 ++++++++++++++++++++++++++
 llms/mlx_lm/utils.py                  |  15 +-
 2 files changed, 514 insertions(+), 6 deletions(-)
 create mode 100644 llms/mlx_lm/models/recurrent_gemma.py

diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py
new file mode 100644
index 00000000..1988abd4
--- /dev/null
+++ b/llms/mlx_lm/models/recurrent_gemma.py
@@ -0,0 +1,505 @@
+import math
+from dataclasses import dataclass
+from typing import List, Literal, Optional
+
+import mlx.core as mx
+import mlx.nn as nn
+
+from .base import BaseModelArgs
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    model_type: str
+    hidden_size: int
+    attention_bias: bool
+    conv1d_width: int
+    embeddings_scale_by_sqrt_dim: bool
+    intermediate_size: int
+    logits_soft_cap: float
+    num_attention_heads: int
+    num_hidden_layers: int
+    num_key_value_heads: int
+    rms_norm_eps: float
+    rope_theta: float
+    attention_window_size: int
+    vocab_size: int
+    _block_types: List[str]
+
+
+def create_window_causal_mask(N: int, window_size: int):
+    inds = mx.arange(N)
+    linds = inds[:, None]
+    rinds = inds[None]
+    mask = (linds < rinds) | (linds > rinds + window_size)
+    return mask * -1e9
+
+
+class RecurrentCache:
+
+    def __init__(self):
+        self._cache = (None, None)
+
+    def __getitem__(self, idx):
+        return self._cache[idx]
+
+    def update(self, conv_state, recurrent_state):
+        self._cache = (conv_state, recurrent_state)
+
+
+class WindowKVCache:
+
+    def __init__(self, window_size):
+        self.keys = None
+        self.values = None
+        self.offset = 0
+        self.window_size = window_size
+
+    def update_and_fetch(self, keys, values):
+        # TODO consider using rotating buffer here
+        # especially for very long generations
+        def _update(x, v):
+            t = x.shape[2] - self.window_size
+            if t > 0:
+                x = x[..., t:, :]
+            return mx.concatenate([x, v], axis=2)
+
+        self.offset += keys.shape[2]
+        if self.keys is None:
+            self.keys = keys
+            self.values = values
+        else:
+            self.keys = _update(self.keys, keys)
+            self.values = _update(self.values, values)
+        return self.keys, self.values
+
+
+class RMSNorm(nn.Module):
+    def __init__(self, dims: int, eps: float = 1e-5):
+        super().__init__()
+        self.weight = mx.ones((dims,))
+        self.eps = eps
+
+    def __call__(self, x):
+        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
+
+
+def rnn_scan(x, a, h0):
+    assert x.ndim == 3
+    assert a.shape == x.shape[-a.ndim :]
+    assert a.dtype == x.dtype
+
+    if x.shape[1] == 1:
+        # Using scan in sampling mode.
+        if h0 is None:
+            return x, x[:, 0]
+
+        else:
+            y = a * h0[:, None] + x
+            return y, y[:, -1]
+
+    else:
+        # Using scan in linear mode.
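+        # Sequentially apply the recurrence h_t = a_t * h_{t-1} + x_t along the
+        # time axis, starting from h0 (or zeros), and return the full sequence
+        # of hidden states together with the final state.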
+        if h0 is not None:
+            h_t = h0
+        else:
+            B, _, D = x.shape
+            h_t = mx.zeros((B, D), dtype=x.dtype)
+
+        y = mx.zeros_like(x)
+        for t in range(x.shape[1]):
+            h_t = a[:, t] * h_t + x[:, t]
+            y[:, t] = h_t
+
+        return y, h_t
+
+
+class Conv1d(nn.Module):
+    def __init__(
+        self,
+        channels: int,
+        kernel_size: int,
+    ):
+        super().__init__()
+        self.weight = mx.zeros((kernel_size, channels))
+        self.bias = mx.zeros((channels,))
+
+    def __call__(self, x, cache=None):
+        w = self.weight.T[..., None]
+        kw, groups = self.weight.shape
+        if cache is not None:
+            l = []
+            # Pad the cache if needed
+            if cache.shape[1] < kw - 1:
+                l.append(
+                    mx.zeros(
+                        (x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
+                    )
+                )
+            l.extend([cache, x])
+            x = mx.concatenate(l, axis=1)
+            y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
+        else:
+            y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
+
+        # The cache is always kw - 1
+        cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
+        y = y + self.bias
+        return y, cache
+
+
+class RGLRU(nn.Module):
+    """A Real-Gated Linear Recurrent Unit (RG-LRU) layer."""
+
+    def __init__(
+        self,
+        width: int,
+        num_heads: int,
+    ):
+        super().__init__()
+        self.width = width
+        self.num_heads = num_heads
+        self.head_dim = self.width // self.num_heads
+
+        self.recurrent_param = mx.zeros((self.width,))
+
+        self.input_gate_weight = mx.zeros(
+            (self.num_heads, self.head_dim, self.head_dim),
+        )
+        self.input_gate_bias = mx.zeros((self.num_heads, self.head_dim))
+
+        self.recurrent_gate_weight = mx.zeros(
+            (self.num_heads, self.head_dim, self.head_dim),
+        )
+        self.recurrent_gate_bias = mx.zeros((self.num_heads, self.head_dim))
+
+    def __call__(
+        self,
+        x: mx.array,
+        cache=None,
+    ):
+        B, L, _ = x.shape
+
+        def apply_block_linear(h, w, b):
+            h = h.reshape((B, L, self.num_heads, self.head_dim))
+            h = (h.swapaxes(1, 2) @ w).swapaxes(1, 2) + b
+            return mx.sigmoid(h.flatten(2, 3))
+
+        # Gates for x and a.
+        gate_x = apply_block_linear(x, self.input_gate_weight, self.input_gate_bias)
+        gate_a = apply_block_linear(
+            x, self.recurrent_gate_weight, self.recurrent_gate_bias
+        )
+
+        # Compute the parameter `A` of the recurrence.
+        log_a = -8.0 * gate_a * nn.softplus(self.recurrent_param)
+        a = mx.exp(log_a)
+        a_square = mx.exp(2 * log_a)
+
+        # Gate the input.
+        gated_x = x * gate_x
+
+        # Apply gamma normalization to the input.
+        multiplier = mx.sqrt(1 - a_square)
+        normalized_x = gated_x * multiplier.astype(x.dtype)
+
+        y, last_h = rnn_scan(
+            x=normalized_x,
+            a=a,
+            h0=cache,
+        )
+
+        return y, last_h
+
+
+class RecurrentBlock(nn.Module):
+
+    def __init__(
+        self,
+        width: int,
+        num_heads: int,
+        lru_width: int = None,
+        conv1d_temporal_width: int = 4,
+    ):
+        super().__init__()
+        self.width = width
+        self.num_heads = num_heads
+        self.lru_width = lru_width or width
+        self.conv1d_temporal_width = conv1d_temporal_width
+
+        self.linear_y = nn.Linear(width, self.lru_width)
+        self.linear_x = nn.Linear(width, self.lru_width)
+        self.linear_out = nn.Linear(self.lru_width, width)
+        self.conv_1d = Conv1d(
+            channels=self.lru_width,
+            kernel_size=self.conv1d_temporal_width,
+        )
+        self.rg_lru = RGLRU(
+            width=self.lru_width,
+            num_heads=self.num_heads,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        cache=None,
+        mask=None,
+    ):
+        # y branch.
+        y = self.linear_y(x)
+        y = nn.gelu_approx(y)
+
+        # x branch.
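+        # The x branch runs a short depthwise temporal convolution followed by
+        # the RG-LRU recurrence; both carry state across calls via the cache.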
+        x = self.linear_x(x)
+        if cache is None:
+            conv_state, recurrent_state = (None, None)
+        else:
+            conv_state, recurrent_state = cache[0], cache[1]
+        x, conv_state = self.conv_1d(
+            x=x,
+            cache=conv_state,
+        )
+        x, recurrent_state = self.rg_lru(
+            x=x,
+            cache=recurrent_state,
+        )
+        if cache is not None:
+            cache.update(conv_state, recurrent_state)
+
+        x = x * y
+        x = self.linear_out(x)
+
+        return x
+
+
+class LocalAttentionBlock(nn.Module):
+
+    def __init__(
+        self,
+        width: int,
+        num_heads: int,
+        window_size: int,
+    ):
+        super().__init__()
+        self.width = width
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.scale = (width // num_heads) ** (-0.5)
+
+        self.head_dim = self.width // self.num_heads
+        self.q_proj = nn.Linear(self.width, self.width, bias=False)
+        self.k_proj = nn.Linear(self.width, self.head_dim, bias=False)
+        self.v_proj = nn.Linear(self.width, self.head_dim, bias=False)
+        self.o_proj = nn.Linear(self.width, self.width, bias=True)
+        self.rope = nn.RoPE(
+            self.head_dim // 2,
+            traditional=False,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        cache=None,
+        mask=None,
+    ):
+        B, L, D = x.shape
+
+        queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
+
+        queries = queries.reshape(B, L, self.num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, L, 1, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            queries = self.rope(queries, offset=cache.offset)
+            keys = self.rope(keys, offset=cache.offset)
+            keys, values = cache.update_and_fetch(keys, values)
+        else:
+            queries = self.rope(queries)
+            keys = self.rope(keys)
+
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
+        return self.o_proj(output)
+
+
+class MLPBlock(nn.Module):
+
+    def __init__(self, width: int, expanded_width: int):
+        super().__init__()
+        self.up_proj = nn.Linear(width, expanded_width // 2)
+        self.gate_proj = nn.Linear(width, expanded_width // 2)
+        self.down_proj = nn.Linear(expanded_width // 2, width)
+
+    def __call__(self, x: mx.array):
+        gate = self.gate_proj(x)
+        x = self.up_proj(x)
+        return self.down_proj(nn.gelu_approx(gate) * x)
+
+
+class ResidualBlock(nn.Module):
+
+    def __init__(
+        self,
+        width: int,
+        mlp_expanded_width: int,
+        num_heads: int,
+        attention_window_size: int,
+        temporal_block_type: str,
+        lru_width: Optional[int] = None,
+        conv1d_temporal_width: int = 4,
+    ):
+        """Initializes the residual block.
+
+        Args:
+            width: The width of the block.
+            mlp_expanded_width: The width of the expansion inside the MLP block.
+            num_heads: The number of heads for the Attention or the RG-LRU.
+            attention_window_size: The window size for the local attention block.
+            temporal_block_type: Either "recurrent" or "attention", specifying the
+                type of temporal block to use.
+            lru_width: The width of the RG-LRU if different from `width`.
+            conv1d_temporal_width: The width of the temporal convolution.
+        """
+        super().__init__()
+        self.width = width
+        self.mlp_expanded_width = mlp_expanded_width
+        self.num_heads = num_heads
+        self.attention_window_size = attention_window_size
+        self.temporal_block_type = temporal_block_type
+        self.lru_width = lru_width
+        self.conv1d_temporal_width = conv1d_temporal_width
+
+        self.temporal_pre_norm = RMSNorm(width)
+        if self.temporal_block_type == "recurrent":
+            self.temporal_block = RecurrentBlock(
+                width=self.width,
+                num_heads=self.num_heads,
+                lru_width=self.lru_width,
+                conv1d_temporal_width=self.conv1d_temporal_width,
+            )
+
+        else:
+            self.temporal_block = LocalAttentionBlock(
+                width=self.width,
+                num_heads=self.num_heads,
+                window_size=self.attention_window_size,
+            )
+
+        self.channel_pre_norm = RMSNorm(width)
+        self.mlp_block = MLPBlock(
+            width=self.width,
+            expanded_width=self.mlp_expanded_width,
+        )
+
+    def __call__(
+        self,
+        x: mx.array,
+        cache=None,
+        mask=None,
+    ):
+        raw_x = x
+
+        inputs_normalized = self.temporal_pre_norm(raw_x)
+        x = self.temporal_block(inputs_normalized, cache=cache, mask=mask)
+
+        residual = x + raw_x
+
+        x = self.channel_pre_norm(residual)
+        x = self.mlp_block(x)
+
+        x = x + residual
+
+        return x
+
+
+class Griffin(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+
+        self.config = config
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+
+        self.scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim
+        block_types = config._block_types
+
+        self.layers = [
+            ResidualBlock(
+                width=config.hidden_size,
+                mlp_expanded_width=config.intermediate_size,
+                num_heads=config.num_attention_heads,
+                attention_window_size=config.attention_window_size,
+                temporal_block_type=block_types[i % len(block_types)],
+                lru_width=None,
+            )
+            for i in range(config.num_hidden_layers)
+        ]
+        self.final_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def __call__(
+        self,
+        tokens,
+        cache=None,
+    ):
+        x = self.embed_tokens(tokens)
+        if self.scale_by_sqrt_dim:
+            x = x * math.sqrt(x.shape[-1])
+
+        mask = None
+        if x.shape[1] > 1:
+            mask = create_window_causal_mask(
+                x.shape[1], self.config.attention_window_size
+            )
+            mask = mask.astype(x.dtype)
+
+        for i, block in enumerate(self.layers):
+            x = block(x, mask=mask, cache=cache[i])
+
+        x = self.final_norm(x)
+        logits = self.embed_tokens.as_linear(x)
+
+        c = self.config.logits_soft_cap
+        if c:
+            logits = mx.tanh(logits / c) * c
+
+        return logits
+
+
+class Model(nn.Module):
+
+    def __init__(self, config):
+        super().__init__()
+        self.args = config
+        self.model = Griffin(config)
+
+    def __call__(self, tokens: mx.array, cache=None) -> mx.array:
+        """
+        Args:
+            tokens: Sequence of input tokens.
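+            cache: Optional list of per-layer caches, as returned by
+                `make_cache`, used for incremental decoding.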
+        """
+        return self.model(tokens, cache=cache)
+
+    @property
+    def layers(self):
+        return self.model.layers
+
+    def sanitize(self, weights):
+        # Reshape the depthwise conv1d weights to the (kernel_size, channels)
+        # layout used by the Conv1d layer above
+        for k, v in weights.items():
+            if "conv_1d.weight" in k and v.ndim == 3:
+                weights[k] = v.squeeze(1).T
+        return weights
+
+    def make_cache(self):
+        cache = []
+        for layer in self.layers:
+            if layer.temporal_block_type == "recurrent":
+                cache.append(RecurrentCache())
+            else:
+                cache.append(WindowKVCache(self.args.attention_window_size))
+        return cache
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index de9c877a..229ee238 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -181,12 +181,15 @@ def generate_step(
     )

     y = prompt
-    kv_heads = (
-        [model.n_kv_heads] * len(model.layers)
-        if isinstance(model.n_kv_heads, int)
-        else model.n_kv_heads
-    )
-    cache = [KVCache(model.head_dim, n) for n in kv_heads]
+    if hasattr(model, "make_cache"):
+        cache = model.make_cache()
+    else:
+        kv_heads = (
+            [model.n_kv_heads] * len(model.layers)
+            if isinstance(model.n_kv_heads, int)
+            else model.n_kv_heads
+        )
+        cache = [KVCache(model.head_dim, n) for n in kv_heads]

     repetition_context = prompt.tolist()
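
A minimal usage sketch of the new code path: with this change, `generate_step` calls
`model.make_cache()` when the model provides it, so the mixed recurrent/attention cache
is built automatically. The checkpoint name below is an assumption; any RecurrentGemma
weights in a format mlx-lm can load should behave the same way.

    from mlx_lm import load, generate

    # Hypothetical checkpoint; substitute any RecurrentGemma model you have converted.
    model, tokenizer = load("google/recurrentgemma-2b-it")
    print(generate(model, tokenizer, prompt="Why is the sky blue?", max_tokens=64))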