mlx-examples/llms/mlx_lm/models/mamba2.py

301 lines
9.3 KiB
Python
Raw Normal View History

2024-10-02 18:48:15 +08:00
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
2024-10-20 22:11:39 +08:00
import mlx.core as mx
2024-10-21 00:04:34 +08:00
import mlx.nn as nn
2024-10-20 22:11:39 +08:00
2024-10-02 18:48:15 +08:00
from .base import BaseModelArgs
from .cache import MambaCache
2024-10-02 18:48:15 +08:00
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters of a Mamba-2 language model.

    Field order matters: every field without a default must precede
    ``norm_before_gate``, the only field with a default.
    """

    model_type: str
    # Model geometry.
    num_heads: int
    head_dim: int
    vocab_size: int
    hidden_size: int
    state_size: int
    num_hidden_layers: int
    layer_norm_epsilon: float
    expand: int
    conv_kernel: int
    n_groups: int
    # Projection / convolution bias switches.
    use_bias: bool
    use_conv_bias: bool
    initializer_range: float
    residual_in_fp32: bool
    chunk_size: int
    tie_word_embeddings: bool
    # Time-step (dt) configuration.
    time_step_limit: Tuple[float, float]
    time_step_rank: Union[int, str]
    time_step_min: float
    time_step_max: float
    time_step_floor: float
    norm_before_gate: bool = True

    def __post_init__(self):
        hidden = self.hidden_size
        if not hasattr(self, "intermediate_size"):
            # intermediate_size is not declared among the fields above, so it
            # is normally derived here from expand * hidden_size.
            self.intermediate_size = int(self.expand * hidden)
        if not hasattr(self, "head_dim"):
            # NOTE(review): head_dim is a required field, so the dataclass
            # __init__ normally sets it before this runs; kept for parity.
            self.head_dim = hidden // self.num_heads
        rank = self.time_step_rank
        if rank == "auto":
            # "auto" resolves to ceil(hidden_size / 16).
            self.time_step_rank = math.ceil(hidden / 16)
2024-10-24 22:16:42 +08:00
def ssd_forward_attn(
    x: mx.array,
    dt: mx.array,
    A: mx.array,
    B: mx.array,
    C: mx.array,
    D: mx.array,
    dt_bias: mx.array,
    dt_min: float,
    dt_max: float,
    prev_state=None,
) -> Tuple[mx.array, mx.array]:
    """Run the Mamba-2 SSD recurrence over a sequence in quadratic ("attention") form.

    Implements, per head, h_t = exp(dt_t * A) * h_{t-1} + dt_t * x_t B_t^T and
    y_t = h_t C_t + D * x_t, evaluated for all t at once.

    Args:
        x: (batch, seq_len, n_heads, head_dim) inputs.
        dt: (batch, seq_len, n_heads) raw time steps (before bias/softplus).
        A: (n_heads,) per-head decay rates (Mamba2Block passes -exp(A_log)).
        B, C: (batch, seq_len, n_groups, d_state); n_heads must be a multiple
            of n_groups, and consecutive heads share a group.
        D: (n_heads,) skip-connection weights, or None.
        dt_bias: (n_heads,) bias added to dt before softplus, or None.
        dt_min, dt_max: clamp range applied to dt after softplus.
        prev_state: (batch, n_heads, head_dim, d_state) state carried in from
            earlier tokens, or None to start from zero.

    Returns:
        y: (batch, seq_len, n_heads * head_dim) outputs.
        next_state: (batch, n_heads, head_dim, d_state) state after the last token.
    """
    b, l, h, dh = x.shape
    _, _, g, n = B.shape
    heads_per_group = h // g

    if l == 0:
        # Nothing to scan: no outputs, and the state passes through unchanged.
        if prev_state is None:
            prev_state = mx.zeros((b, h, dh, n), dtype=x.dtype)
        return mx.zeros((b, 0, h * dh), dtype=x.dtype), prev_state

    # Time step: optional bias, softplus, then clamp.
    if dt_bias is not None:
        dt = dt + dt_bias.reshape(1, 1, -1)
    dt = nn.softplus(dt)
    dt = mx.clip(dt, a_min=dt_min, a_max=dt_max)

    # Group-major layouts; C B^T is computed per group, then shared across heads.
    B_g = mx.swapaxes(B, 1, 2)  # (b, g, l, n)
    C_g = mx.swapaxes(C, 1, 2)  # (b, g, l, n)
    CB = mx.repeat(C_g @ mx.swapaxes(B_g, 2, 3), heads_per_group, axis=1)  # (b, h, l, l)

    dtA = mx.swapaxes(dt * A.reshape(1, 1, -1), 1, 2)  # (b, h, l)
    # decay[..., i, j] = exp(dtA[j+1] + ... + dtA[i]) for j <= i, 0 for j > i,
    # so the product below is already causal.
    decay = mx.exp(segsum(dtA))  # (b, h, l, l)

    dtx = mx.swapaxes(dt[..., None] * x, 1, 2)  # (b, h, l, dh)
    y = (CB * decay) @ dtx  # (b, h, l, dh)

    # State after the last token: sum_j exp(dtA[j+1..l-1]) * dt_j x_j B_j^T.
    decay_to_end = decay[:, :, -1, :][..., None]  # (b, h, l, 1)
    B_h = mx.repeat(B_g, heads_per_group, axis=1)  # (b, h, l, n)
    next_state = mx.swapaxes(dtx * decay_to_end, 2, 3) @ B_h  # (b, h, dh, n)

    if prev_state is not None:
        # The incoming state decays by exp(dtA[0] + ... + dtA[i]) up to step i
        # and contributes prev_state @ C_i to the output there.
        decay_from_start = mx.exp(mx.cumsum(dtA, axis=-1))  # (b, h, l)
        C_h = mx.repeat(C_g, heads_per_group, axis=1)  # (b, h, l, n)
        y = y + (C_h @ mx.swapaxes(prev_state, 2, 3)) * decay_from_start[..., None]
        next_state = next_state + prev_state * decay_from_start[:, :, -1].reshape(b, h, 1, 1)

    y = mx.swapaxes(y, 1, 2)  # (b, l, h, dh)
    if D is not None:
        # Skip connection.
        y = y + x * D.reshape(1, 1, h, 1)
    return y.reshape(b, l, h * dh), next_state


def segsum(x):
    """Return the segment-sum matrix of ``x`` along its last axis.

    For ``x`` of shape (..., l) the result has shape (..., l, l) with
    ``out[..., i, j] = x[..., j+1] + ... + x[..., i]`` for ``j <= i`` (0 on the
    diagonal) and ``-inf`` for ``j > i``. ``exp(out)`` is therefore a
    lower-triangular decay matrix.
    """
    l = x.shape[-1]
    idx = mx.arange(l)
    # x_rep[..., i, j] = x[..., i]
    x_rep = mx.broadcast_to(x[..., None], (*x.shape, l))
    # Zero everything except the strictly-lower part (i > j), then accumulate
    # down the rows. This sums x[j+1..i] directly instead of subtracting two
    # large prefix sums, which is numerically more stable.
    x_rep = mx.where(idx[:, None] > idx[None, :], x_rep, 0)
    out = mx.cumsum(x_rep, axis=-2)
    return mx.where(idx[:, None] >= idx[None, :], out, float("-inf"))
2024-11-06 23:35:46 +08:00
class Mamba2Block(nn.Module):
    """Mamba-2 mixer: in_proj -> causal depthwise conv -> SSD scan -> gated norm -> out_proj."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.d_model = args.hidden_size
        self.d_state = args.state_size
        self.d_conv = args.conv_kernel
        self.expand = args.expand
        self.d_inner = int(self.expand * self.d_model)
        self.n_groups = args.n_groups
        self.n_heads = args.num_heads
        self.d_head = self.d_inner // self.n_heads
        self.chunk_size = args.chunk_size

        # in_proj packs [z | x | B | C | dt] along the last axis.
        d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)

        self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
        self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range

        conv_channels = self.d_inner + 2 * self.n_groups * self.d_state
        # Depthwise conv. The causal left context (cached tail or zeros) is
        # concatenated manually in _causal_conv, so the layer itself is unpadded.
        self.conv1d = nn.Conv1d(
            in_channels=conv_channels,
            out_channels=conv_channels,
            kernel_size=self.d_conv,
            groups=conv_channels,
            padding=0,
            bias=args.use_conv_bias,
        )
        self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)

    def _causal_conv(self, xBC: mx.array, conv_state) -> Tuple[mx.array, mx.array]:
        """Apply the causal depthwise conv and SiLU, and return the new conv tail.

        Args:
            xBC: (batch, seq_len, channels) channels-last input, as MLX Conv1d expects.
            conv_state: previous (batch, d_conv - 1, channels) input tail, or None.

        Returns:
            (activated output of shape (batch, seq_len, channels), new tail).
        """
        batch_size, seq_len, channels = xBC.shape
        keep = self.d_conv - 1
        if conv_state is None:
            conv_state = mx.zeros((batch_size, keep, channels), dtype=xBC.dtype)
        padded = mx.concatenate([conv_state, xBC], axis=1)  # (B, keep + L, C)
        # Slice from an explicit start index: a negative `-keep` would select
        # everything when keep == 0 (d_conv == 1).
        next_conv_state = padded[:, padded.shape[1] - keep :, :]
        if seq_len == 0:
            return xBC, next_conv_state
        # Unpadded conv over keep + L steps yields exactly L causal outputs.
        return nn.silu(self.conv1d(padded)), next_conv_state

    def __call__(self, u: mx.array, cache=None):
        batch_size, seq_len, _ = u.shape
        if cache is None:
            # Throw-away cache so the state bookkeeping below is uniform.
            cache = [None, None]
        conv_state, ssm_state = cache[0], cache[1]

        zxBCdt = self.in_proj(u)
        z, xBC, dt = mx.split(
            zxBCdt,
            [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
            axis=-1,
        )

        xBC, next_conv_state = self._causal_conv(xBC, conv_state)

        x, B, C = mx.split(
            xBC,
            [self.d_inner, self.d_inner + self.d_state * self.n_groups],
            axis=-1,
        )
        # Explicit sizes (no -1) so empty sequences reshape cleanly.
        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
        B = mx.reshape(B, (batch_size, seq_len, self.n_groups, self.d_state))
        C = mx.reshape(C, (batch_size, seq_len, self.n_groups, self.d_state))

        # dt is clamped to time_step_limit. time_step_min/max only describe
        # the initialisation range of dt_bias, not a runtime clamp.
        time_step_limit = self.args.time_step_limit
        y, next_ssm_state = ssd_forward_attn(
            x=x,
            dt=dt,
            A=-mx.exp(self.A_log),
            B=B,
            C=C,
            D=self.D,
            dt_bias=self.dt_bias,
            dt_min=time_step_limit[0],
            dt_max=time_step_limit[1],
            prev_state=ssm_state,
        )

        if self.args.norm_before_gate:
            y = self.norm(y) * nn.silu(z)
        else:
            y = self.norm(y * nn.silu(z))
        y = self.out_proj(y)

        cache[0] = next_conv_state
        cache[1] = next_ssm_state
        return y
2024-10-02 18:48:15 +08:00
2024-10-21 00:04:34 +08:00
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper: x + mixer(norm(x))."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.residual_in_fp32 = args.residual_in_fp32
        self.mixer = Mamba2Block(args)
        # Use the configured epsilon, as Mamba2Block.norm and Mamba2.norm_f do;
        # RMSNorm's default eps would otherwise silently differ from the config.
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache):
        if self.residual_in_fp32:
            # Upcast before the norm so the residual sum is accumulated in fp32.
            x = x.astype(mx.float32)
        return self.mixer(self.norm(x), cache) + x
2024-11-06 23:35:46 +08:00
2024-12-27 22:27:09 +08:00
2024-11-06 23:35:46 +08:00
class Mamba2(nn.Module):
    """Mamba-2 backbone: token embedding, stacked residual blocks, final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache):
        h = self.embeddings(x)
        # Without a cache every layer runs stateless (receives None).
        per_layer = cache if cache is not None else [None] * len(self.layers)
        for block, block_cache in zip(self.layers, per_layer):
            h = block(h, block_cache)
        return self.norm_f(h)
2024-10-02 18:48:15 +08:00
2024-11-06 23:35:46 +08:00
2024-10-02 18:48:15 +08:00
class Model(nn.Module):
    """Causal LM: Mamba-2 backbone plus an output head (tied to the embeddings or separate)."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba2(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        hidden = self.backbone(inputs, cache)
        if self.args.tie_word_embeddings:
            logits = self.backbone.embeddings.as_linear(hidden)
        else:
            logits = self.lm_head(hidden)
        return logits

    def make_cache(self):
        """Return one fresh MambaCache per layer."""
        return [MambaCache() for _ in range(len(self.layers))]

    def sanitize(self, weights):
        """Convert depthwise conv weights from PyTorch layout to MLX layout.

        PyTorch stores Conv1d weights as (out, in/groups, kernel), while MLX
        expects (out, kernel, in/groups). For the depthwise conv here that means
        (C, 1, K) -> (C, K, 1). Weights whose last axis is already 1 are assumed
        to be in MLX layout and are left unchanged.
        """
        for k, v in weights.items():
            if "conv1d.weight" in k and v.shape[-1] != 1:
                weights[k] = mx.moveaxis(v, 2, 1)
        return weights

    @property
    def layers(self):
        return self.backbone.layers