mlx-examples/llms/mlx_lm/models/mamba2.py

326 lines
9.3 KiB
Python
Raw Normal View History

2024-10-02 18:48:15 +08:00
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
2024-10-20 22:11:39 +08:00
import mlx.core as mx
2024-10-21 00:04:34 +08:00
import mlx.nn as nn
2024-10-20 22:11:39 +08:00
2024-10-02 18:48:15 +08:00
from .base import BaseModelArgs
from .cache import MambaCache
2024-10-02 18:48:15 +08:00
@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration for the Mamba2 model.

    Besides the declared fields, ``__post_init__`` derives ``intermediate_size``
    and resolves ``time_step_rank == "auto"`` to a concrete integer.
    """

    model_type: str
    num_heads: int
    head_dim: int
    vocab_size: int
    hidden_size: int
    state_size: int
    num_hidden_layers: int
    layer_norm_epsilon: float
    expand: int
    conv_kernel: int
    n_groups: int
    use_bias: bool
    use_conv_bias: bool
    initializer_range: float
    residual_in_fp32: bool
    rescale_prenorm_residual: bool
    rms_norm: bool
    chunk_size: int
    tie_word_embeddings: bool
    time_step_limit: Tuple[float, float]
    time_step_rank: Union[int, str]
    time_step_min: float
    time_step_max: float
    time_step_floor: float
    norm_before_gate: bool = True

    def __post_init__(self):
        # The literal string "auto" is replaced by ceil(hidden_size / 16).
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)

        # head_dim is a required dataclass field, so this fallback only fires
        # if the attribute was removed before __post_init__ ran.
        if not hasattr(self, "head_dim"):
            self.head_dim = self.hidden_size // self.num_heads

        # intermediate_size is not a declared field; derive it from expand
        # unless something has already attached it to the instance.
        if not hasattr(self, "intermediate_size"):
            self.intermediate_size = int(self.hidden_size * self.expand)
class MambaRMSNormGated(nn.Module):
    """RMS normalisation over the last axis with an optional SiLU gate.

    Args:
        hidden_size: size of the last axis; the learnable scale has this length.
        eps: added to the mean square before the reciprocal square root.
        norm_before_gate: when a gate ``z`` is given, True computes
            ``norm(x) * silu(z)``, False computes ``norm(x * silu(z))``.
    """

    def __init__(self, hidden_size, eps=1e-6, norm_before_gate=False):
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps
        self.norm_before_gate = norm_before_gate

    def rms_norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last axis, then by ``weight``."""
        mean_square = mx.mean(x ** 2, axis=-1, keepdims=True)
        normalized = x * mx.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized

    def __call__(self, x, z=None):
        # Without a gate this is a plain RMSNorm.
        if z is None:
            return self.rms_norm(x)
        gate = nn.silu(z)
        if self.norm_before_gate:
            return self.rms_norm(x) * gate
        return self.rms_norm(x * gate)
def silu(x):
    """Return the elementwise SiLU (swish) activation ``x * sigmoid(x)``."""
    gate = mx.sigmoid(x)
    return x * gate
class DepthWiseConv1d(nn.Module):
    """Causal depthwise 1D convolution (one filter per channel) with a rolling window.

    Args:
        channels: number of input/output channels (also the group count).
        kernel_size: temporal width K of each filter.
        bias: whether to add a learnable per-channel bias.
        padding: stored for reference only; ``__call__`` always left-pads by
            ``K - 1`` when no cache is supplied.
    """

    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.padding = padding
        # Layout (out_channels, K, in_channels / groups) as passed to mx.conv_general.
        self.weight = mx.random.normal((channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x, cache=None):
        """Convolve ``x`` of shape (batch, seq, channels) causally.

        Args:
            x: input sequence.
            cache: the last ``K - 1`` inputs from a previous call, or None to
                start from zero left-padding.

        Returns:
            ``(y, new_cache)`` where ``y`` has the same sequence length as ``x``
            and ``new_cache`` holds the last ``K - 1`` (padded) inputs.
        """
        channels = x.shape[-1]
        K = self.weight.shape[1]
        if cache is not None:
            x = mx.concatenate([cache, x], axis=1)
        else:
            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
        y = mx.conv_general(x, self.weight, groups=channels)
        # bias=False stores None; adding None would raise.
        if self.bias is not None:
            y = y + self.bias
        # Slice with a non-negative start: for K == 1 the old `x[:, -K + 1:]`
        # became `x[:, 0:]` and returned the whole input instead of nothing.
        return y, x[:, x.shape[1] - (K - 1):, :]
def ssd_forward_attn(
    x: mx.array,
    dt: mx.array,
    A: mx.array,
    B: mx.array,
    C: mx.array,
    D: mx.array,
    dt_bias: mx.array,
    dt_min: float,
    dt_max: float,
    state: mx.array = None,
) -> Tuple[mx.array, mx.array]:
    """Run the selective-state-space recurrence over a whole chunk in quadratic form.

    Shapes (as produced by ``Mamba2Block``):
        x: (batch, seq, heads, head_dim)
        dt: (batch, seq, heads) raw time steps, before bias/softplus/clip
        A: (heads,) decay rates (negative in ``Mamba2Block``)
        B, C: (batch, seq, groups, state_size); each group is shared by
            ``heads // groups`` consecutive heads
        D: (heads,) skip weights, or None
        dt_bias: (heads,) added to ``dt`` before softplus, or None
        dt_min, dt_max: clip range applied to ``dt`` after softplus
        state: optional (batch, heads, head_dim, state_size) state carried over
            from a previous call; None means starting from a zero state.

    Returns:
        y: (batch, seq, heads * head_dim)
        next_state: (batch, heads, head_dim, state_size), the state after the
            last step, suitable for passing back in as ``state``.
    """
    b, l, h, dh = x.shape
    _, _, g, _ = B.shape
    heads_per_group = h // g

    # Discretisation step: bias -> softplus -> clip.
    if dt_bias is not None:
        dt = dt + dt_bias.reshape(1, 1, -1)
    dt = nn.softplus(dt)
    dt = mx.clip(dt, a_min=dt_min, a_max=dt_max)

    # B: (b, l, g, n) -> (b, g, n, l);  C: (b, l, g, n) -> (b, g, l, n)
    B = mx.swapaxes(mx.swapaxes(B, 1, 3), 1, 2)
    C = mx.swapaxes(C, 1, 2)
    # CB[..., t, s] = <C_t, B_s>, expanded from groups to heads: (b, h, l, l)
    CB = mx.repeat(C @ B, repeats=heads_per_group, axis=1)

    # Per-step log-decay, (b, h, l).
    dtA = mx.swapaxes(dt * A.reshape(1, 1, -1), 1, 2)
    # decay[..., t, s] = exp(dtA_{s+1} + ... + dtA_t) for t >= s
    decay = mx.exp(segsum(dtA))
    surrogate_attention_matrix = mx.tril(CB * decay, 0)

    dtx = dt.reshape(b, l, h, 1) * x
    y = surrogate_attention_matrix @ dtx.swapaxes(1, 2)  # (b, h, l, dh)
    y = mx.swapaxes(y, 1, 2)  # (b, l, h, dh)

    # State after the last step: sum_s exp(dtA_{s+1} + ... + dtA_{l-1}) dt_s x_s B_s^T
    decay_to_end = (
        decay[:, :, -1, :].reshape(b, h, l).swapaxes(1, 2).reshape(b, l, h, 1)
    )
    B_heads = mx.repeat(B, heads_per_group, axis=1).swapaxes(2, 3)  # (b, h, l, n)
    dtxdecay = (dtx * decay_to_end).swapaxes(1, 2).swapaxes(2, 3)  # (b, h, dh, l)
    next_state = dtxdecay @ B_heads  # (b, h, dh, n)

    if state is not None:
        # The carried-in state has decayed by exp(dtA_0 + ... + dtA_t) at step t
        # and is read out through C_t; it also seeds the outgoing state.
        decay_from_start = mx.exp(mx.cumsum(dtA, axis=-1))  # (b, h, l)
        C_heads = mx.repeat(C, heads_per_group, axis=1)  # (b, h, l, n)
        y_state = (C_heads @ state.swapaxes(-1, -2)) * decay_from_start[..., None]
        y = y + y_state.swapaxes(1, 2).astype(y.dtype)  # (b, l, h, dh)
        total_decay = mx.exp(mx.sum(dtA, axis=-1))[..., None, None]  # (b, h, 1, 1)
        next_state = next_state + total_decay * state

    if D is not None:
        y += x * D.reshape(1, 1, h, 1)
    y = y.reshape(b, l, h * dh)
    return y, next_state


def segsum(x):
    """Pairwise segment sums along the last axis.

    For input of shape (..., l) returns shape (..., l, l) with
    ``out[..., i, j] = x[..., j + 1] + ... + x[..., i]`` for ``i > j`` and
    0 for ``i <= j``.
    """
    l = x.shape[-1]
    x = mx.repeat(x[..., None], l, axis=-1)
    # Keep entries strictly below the diagonal (row k > column j).
    x = mx.tril(x, -1)
    x_segsum = mx.cumsum(x, axis=-2)
    return x_segsum


class Mamba2Block(nn.Module):
    """Mamba2 mixer layer.

    Pipeline: input projection -> causal depthwise conv + SiLU -> SSD scan ->
    gated RMSNorm -> output projection.

    When a cache is supplied it is indexed as ``cache[0]`` (convolution window)
    and ``cache[1]`` (SSM state). Both are read at the start of the call and
    overwritten with the updated values at the end.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        # Dimensions
        self.d_model = args.hidden_size
        self.d_state = args.state_size
        self.d_conv = args.conv_kernel
        self.expand = args.expand
        self.d_inner = int(self.expand * self.d_model)
        self.n_groups = args.n_groups
        self.n_heads = args.num_heads
        self.d_head = self.d_inner // self.n_heads
        self.chunk_size = args.chunk_size

        # Input projection, split later as [z | x, B, C | dt].
        d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)

        # Per-head SSM parameters (random placeholders at construction time).
        self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
        self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range

        # Convolution over the concatenated x, B and C channels.
        self.conv1d = DepthWiseConv1d(
            channels=self.d_inner + 2 * self.n_groups * self.d_state,
            kernel_size=self.d_conv,
            bias=args.use_conv_bias,
            padding=self.d_conv - 1,
        )

        # Output projections
        self.norm = MambaRMSNormGated(
            self.d_inner,
            eps=args.layer_norm_epsilon,
            norm_before_gate=args.norm_before_gate,
        )
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)

    def __call__(self, u: mx.array, cache=None):
        batch_size, seq_len, _ = u.shape
        conv_cache = None if cache is None else cache[0]
        ssm_state = None if cache is None else cache[1]

        # Project input and split into gate, conv input and time steps.
        zxBCdt = self.in_proj(u)
        z, xBC, dt = mx.split(
            zxBCdt,
            [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
            axis=-1,
        )

        # Causal convolution (continuing from the cached window, if any).
        xBC, conv_state = self.conv1d(xBC, conv_cache)
        xBC = silu(xBC)
        xBC = xBC[:, :seq_len, :]

        # Split and reshape conv output for the SSM.
        x, B, C = mx.split(
            xBC,
            [self.d_inner, self.d_inner + self.d_state * self.n_groups],
            axis=-1,
        )
        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
        B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
        C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))

        A = -mx.exp(self.A_log)
        # time_step_limit is the runtime clamp for dt; time_step_min/max are the
        # dt initialisation range and must not be used to clip at inference.
        dt_min, dt_max = self.args.time_step_limit
        y, next_state = ssd_forward_attn(
            x=x,
            dt=dt,
            A=A,
            B=B,
            C=C,
            D=self.D,
            dt_bias=self.dt_bias,
            dt_min=dt_min,
            dt_max=dt_max,
            # Continue the recurrence from the cached state (None on the first call).
            state=ssm_state,
        )

        if cache is not None:
            cache[0] = conv_state
            cache[1] = next_state

        y = self.norm(y, z)
        return self.out_proj(y)
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper around the Mamba2 mixer: ``x + mixer(norm(x), cache)``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.residual_in_fp32 = args.residual_in_fp32
        self.mixer = Mamba2Block(args)
        # Use the configured epsilon, like norm_f and the mixer's gated norm;
        # previously nn.RMSNorm's built-in default was used here.
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache):
        # NOTE: residual_in_fp32 is stored but not applied; the residual stays
        # in the input dtype (the upcast is currently disabled).
        normed = self.norm(x)
        output = self.mixer(normed, cache)
        return output + x
class Mamba2(nn.Module):
    """Mamba2 backbone: token embedding, a stack of residual mixer blocks, final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache):
        """Embed token ids ``x`` and run them through every layer.

        ``cache`` is a per-layer sequence of caches, or None to run every
        layer without a cache.
        """
        per_layer = [None] * len(self.layers) if cache is None else cache
        hidden = self.embeddings(x)
        for block, block_cache in zip(self.layers, per_layer):
            hidden = block(hidden, block_cache)
        return self.norm_f(hidden)
class Model(nn.Module):
    """Mamba2 language model: backbone plus either a separate or a tied LM head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba2(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        """Return vocabulary logits for token ids ``inputs``."""
        hidden = self.backbone(inputs, cache)
        if not self.args.tie_word_embeddings:
            return self.lm_head(hidden)
        # Tied weights: reuse the embedding matrix as the output projection.
        return self.backbone.embeddings.as_linear(hidden)

    def sanitize(self, weights):
        """Move conv1d weights whose last dim is not 1 to the (C, K, 1) layout; mutates and returns ``weights``."""
        to_fix = [
            key
            for key, value in weights.items()
            if "conv1d.weight" in key and value.shape[-1] != 1
        ]
        for key in to_fix:
            weights[key] = weights[key].moveaxis(2, 1)
        return weights

    def make_cache(self):
        """Create one fresh MambaCache per layer."""
        return [MambaCache() for _ in self.layers]

    @property
    def layers(self):
        return self.backbone.layers