2024-10-02 18:48:15 +08:00
|
|
|
import math
|
|
|
|
from dataclasses import dataclass, field
|
2024-11-10 23:35:07 +08:00
|
|
|
from typing import Tuple, Union
|
2024-10-20 22:11:39 +08:00
|
|
|
import mlx.core as mx
|
2024-10-21 00:04:34 +08:00
|
|
|
import mlx.nn as nn
|
2024-10-20 22:11:39 +08:00
|
|
|
|
2024-10-02 18:48:15 +08:00
|
|
|
from .base import BaseModelArgs
|
2024-11-10 23:57:03 +08:00
|
|
|
from .cache import MambaCache
|
2024-10-02 18:48:15 +08:00
|
|
|
|
|
|
|
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters for the Mamba2 model.

    ``__post_init__`` derives ``intermediate_size`` (when absent) and resolves
    ``time_step_rank == "auto"`` to a concrete integer.
    """

    model_type: str
    num_heads: int
    head_dim: int
    vocab_size: int
    hidden_size: int
    state_size: int
    num_hidden_layers: int
    layer_norm_epsilon: float
    expand: int
    conv_kernel: int
    n_groups: int
    use_bias: bool
    use_conv_bias: bool
    initializer_range: float
    residual_in_fp32: bool
    rescale_prenorm_residual: bool
    rms_norm: bool
    chunk_size: int
    tie_word_embeddings: bool
    time_step_limit: Tuple[float, float]
    time_step_rank: Union[int, str]
    time_step_min: float
    time_step_max: float
    time_step_floor: float
    # Selects whether the mixer applies its RMSNorm before or after the SiLU gate.
    norm_before_gate: bool = True

    def __post_init__(self):
        """Fill in hyper-parameters that can be derived from the others."""
        hidden = self.hidden_size

        # "auto" rank resolves to hidden_size / 16, rounded up.
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(hidden / 16)

        # head_dim is a required field, so this only triggers if it was removed.
        if not hasattr(self, "head_dim"):
            self.head_dim = hidden // self.num_heads

        # intermediate_size is not a declared field; derive it from expand.
        if not hasattr(self, "intermediate_size"):
            self.intermediate_size = int(self.expand * hidden)
|
|
|
|
|
2024-10-24 22:16:42 +08:00
|
|
|
|
2024-11-06 23:35:46 +08:00
|
|
|
def silu(x):
    """Sigmoid-weighted linear unit, computed as ``x * sigmoid(x)``."""
    gate = mx.sigmoid(x)
    return x * gate
|
|
|
|
|
2024-12-11 01:15:12 +08:00
|
|
|
|
2024-11-11 00:19:00 +08:00
|
|
|
class DepthWiseConv1d(nn.Module):
    """Causal depth-wise 1D convolution with an optional rolling input cache.

    Each channel is convolved with its own filter (``groups`` equals the channel
    count of the input). The weight has shape ``(channels, kernel_size, 1)``.
    """

    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        # Stored for API compatibility only: __call__ always left-pads (or
        # prepends the cache) by kernel_size - 1 itself.
        self.padding = padding
        self.weight = mx.random.normal((channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x, cache=None):
        """Convolve ``x`` causally.

        Args:
            x: input array unpacked as ``(B, L, C)``.
            cache: the trailing ``K - 1`` input steps returned by the previous
                call, or ``None`` to zero-pad the left edge instead.

        Returns:
            ``(y, new_cache)``: ``y`` is the convolution output (length ``L``
            when ``cache`` holds ``K - 1`` steps), and ``new_cache`` is the last
            ``K - 1`` steps of the padded input, to feed into the next call.
        """
        _, _, C = x.shape
        K = self.weight.shape[1]

        if cache is not None:
            x = mx.concatenate([cache, x], axis=1)
        else:
            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])

        y = mx.conv_general(x, self.weight, groups=C)
        # The layer may be built without a bias (bias=False); adding None would fail.
        if self.bias is not None:
            y = y + self.bias

        # Slice from an explicit start index: the negative form x[:, -K + 1:]
        # becomes x[:, 0:] and returns the whole sequence when K == 1.
        new_cache = x[:, x.shape[1] - (K - 1):, :]
        return y, new_cache
|
2024-11-10 23:57:03 +08:00
|
|
|
|
|
|
|
|
2025-01-22 03:44:51 +08:00
|
|
|
def ssd_forward_attn(
    x: mx.array,
    dt: mx.array,
    A: mx.array,
    B: mx.array,
    C: mx.array,
    D: mx.array,
    dt_bias: mx.array,
    dt_min: float,
    dt_max: float,
    state: Union[mx.array, None] = None,
) -> Tuple[mx.array, mx.array]:
    """Evaluate the SSD recurrence over a whole sequence in its quadratic form.

    The scan is computed as a masked (l, l) "attention" matrix that combines
    C·B scores with cumulative per-head decays, optionally continuing from a
    carried-in SSM state.

    Args:
        x: inputs, unpacked as (b, l, h, dh).
        dt: raw per-head time steps, (b, l, h). Bias, softplus and clipping are
            applied here.
        A: per-head decay rates, (h,). Mamba2Block passes ``-exp(A_log)``.
        B, C: input/output projections, (b, l, g, n). Heads map to groups in
            contiguous runs of ``h // g``.
        D: per-head skip weights, (h,), or None to skip the skip term.
        dt_bias: per-head bias added to dt before softplus, or None.
        dt_min, dt_max: clip bounds applied to dt after softplus.
        state: optional SSM state from a previous call, (b, h, dh, n). None
            means a zero initial state, which reproduces the previous behavior.

    Returns:
        (y, next_state): y is (b, l, h * dh); next_state is (b, h, dh, n), the
        state after the last step, suitable for passing back in as ``state``.
    """
    b, l, h, dh = x.shape
    _, _, g, _ = B.shape
    reps = h // g

    # Discretise the time step.
    if dt_bias is not None:
        dt = dt + dt_bias.reshape(1, 1, -1)
    dt = nn.softplus(dt)
    dt = mx.clip(dt, a_min=dt_min, a_max=dt_max)

    B = mx.swapaxes(mx.swapaxes(B, 1, 3), 1, 2)  # (b, g, n, l)
    C = mx.swapaxes(C, 1, 2)  # (b, g, l, n)

    # Pairwise C_i · B_j scores, broadcast from groups to heads.
    CB = C @ B  # (b, g, l, l)
    CB = mx.repeat(CB, repeats=reps, axis=1)  # (b, h, l, l)

    dtA = dt * A.reshape(1, 1, -1)
    dtA = mx.swapaxes(dtA, 1, 2)  # (b, h, l)

    # decay[..., i, j] = exp(sum(dtA[j+1 : i+1])) for i >= j.
    decay = mx.exp(segsum(dtA))  # (b, h, l, l)

    # Lower-triangular mask keeps the scan causal.
    surrogate_attention_matrix = mx.tril(CB * decay, 0)

    dtx = dt.reshape(b, l, h, 1) * x
    y = surrogate_attention_matrix @ dtx.swapaxes(1, 2)  # (b, h, l, dh)
    y = mx.swapaxes(y, 1, 2)  # (b, l, h, dh)

    # State after the last step: each input decayed from its step to the end.
    decay = decay[:, :, -1, :].reshape(b, h, l).swapaxes(1, 2).reshape(b, l, h, 1)
    B = mx.repeat(B, reps, axis=1).swapaxes(2, 3)  # (b, h, l, n)
    dtxdecay = dtx * decay
    dtxdecay = dtxdecay.swapaxes(1, 2).swapaxes(2, 3)  # (b, h, dh, l)
    next_state = dtxdecay @ B  # (b, h, dh, n)

    if state is not None:
        # Decay of the carried-in state from the start of this call up to and
        # including step i.
        decay_in = mx.exp(mx.cumsum(dtA, axis=-1))  # (b, h, l)
        C_heads = mx.repeat(C, reps, axis=1)  # (b, h, l, n)

        # y_i += exp(sum(dtA[0 : i+1])) * (C_i · state)
        y_state = (C_heads @ state.swapaxes(-1, -2)) * decay_in[..., None]  # (b, h, l, dh)
        y = y + mx.swapaxes(y_state, 1, 2)

        # next_state += exp(sum(dtA)) * state
        next_state = next_state + decay_in[:, :, -1].reshape(b, h, 1, 1) * state

    if D is not None:
        y = y + x * D.reshape(1, 1, h, 1)

    y = y.reshape(b, l, h * dh)
    return y, next_state


def segsum(x):
    """Segment sums along the last axis.

    For an input of shape (..., l), returns (..., l, l) with
    ``out[..., i, j] = sum(x[..., j+1 : i+1])`` for ``i >= j`` (0 on the
    diagonal). Entries with ``i < j`` are 0 rather than -inf, so callers must
    mask the upper triangle themselves.
    """
    l = x.shape[-1]
    # x[..., i, j] = x[i] for every j.
    x = mx.repeat(x[..., None], l, axis=-1)
    # Keep only j < i, so the cumsum over i accumulates x[j+1 .. i].
    x = mx.tril(x, -1)
    x_segsum = mx.cumsum(x, axis=-2)
    return x_segsum


class Mamba2Block(nn.Module):
    """Mamba2 mixer.

    Pipeline: input projection -> causal depth-wise conv + SiLU -> SSD scan ->
    gated RMSNorm -> output projection.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.d_model = args.hidden_size
        self.d_state = args.state_size
        self.d_conv = args.conv_kernel
        self.expand = args.expand
        self.d_inner = int(self.expand * self.d_model)
        self.n_groups = args.n_groups
        self.n_heads = args.num_heads
        self.d_head = self.d_inner // self.n_heads
        self.chunk_size = args.chunk_size

        # A single projection yields [z | x | B | C | dt] along the last axis.
        d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)

        # Per-head SSM parameters. The random init is presumably replaced by
        # checkpoint weights when loading.
        self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
        self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
        self.D = mx.random.normal((self.n_heads,)) * args.initializer_range

        # Causal depth-wise convolution over the concatenated x, B and C channels.
        self.conv1d = DepthWiseConv1d(
            channels=self.d_inner + 2 * self.n_groups * self.d_state,
            kernel_size=self.d_conv,
            bias=args.use_conv_bias,
            padding=self.d_conv - 1,
        )

        self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)

    def __call__(self, u: mx.array, cache=None):
        """Mix a (batch, seq_len, hidden) sequence.

        ``cache`` is indexed as ``[conv_state, ssm_state]``. Both entries are
        read at the start and written back at the end, so the next call
        continues exactly where this one stopped.
        """
        batch_size, seq_len, _ = u.shape

        if cache is None:
            cache = [None, None]  # [conv_state, ssm_state]
        conv_state = cache[0]
        ssm_state = cache[1]

        zxBCdt = self.in_proj(u)
        z, xBC, dt = mx.split(
            zxBCdt,
            [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
            axis=-1,
        )

        xBC, conv_state = self.conv1d(xBC, conv_state)
        xBC = silu(xBC)
        xBC = xBC[:, :seq_len, :]

        x, B, C = mx.split(
            xBC,
            [self.d_inner, self.d_inner + self.d_state * self.n_groups],
            axis=-1,
        )

        x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
        B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
        C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))

        A = -mx.exp(self.A_log)
        # Clamp dt with time_step_limit, which is otherwise never read.
        # time_step_min/max previously served as clip bounds here.
        # NOTE(review): they look like dt_bias initialisation bounds instead;
        # confirm against the reference Mamba2 config.
        dt_min, dt_max = self.args.time_step_limit
        y, next_state = ssd_forward_attn(
            x=x,
            dt=dt,
            A=A,
            B=B,
            C=C,
            D=self.D,
            dt_bias=self.dt_bias,
            dt_min=dt_min,
            dt_max=dt_max,
            # Continue from the cached SSM state. Without it, token-by-token
            # decoding would only see the conv window of history.
            state=ssm_state,
        )

        if self.args.norm_before_gate:
            y = self.norm(y)
            y = y * nn.silu(z)
        else:
            y = y * nn.silu(z)
            y = self.norm(y)

        y = self.out_proj(y)

        cache[0] = conv_state
        cache[1] = next_state
        return y
|
2025-01-14 04:28:43 +08:00
|
|
|
|
2024-10-02 18:48:15 +08:00
|
|
|
|
2024-10-21 00:04:34 +08:00
|
|
|
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper: ``x + mixer(norm(x), cache)``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Read from the config but not currently applied in __call__.
        self.residual_in_fp32 = args.residual_in_fp32
        self.mixer = Mamba2Block(args)
        # Use the configured epsilon, consistent with the other RMSNorms in the model.
        self.norm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache):
        """Apply the mixer to the normalised input and add the residual.

        ``cache`` is forwarded unchanged to the mixer.
        """
        # NOTE(review): residual_in_fp32 is ignored here. An fp32 cast of x
        # was previously left commented out; confirm whether it should apply.
        normed = self.norm(x)
        output = self.mixer(normed, cache)
        return output + x
|
2024-11-06 23:35:46 +08:00
|
|
|
|
2024-12-27 22:27:09 +08:00
|
|
|
|
2024-11-06 23:35:46 +08:00
|
|
|
class Mamba2(nn.Module):
    """Mamba2 backbone: token embedding, stacked residual blocks, final RMSNorm."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache):
        """Embed token ids ``x`` and run them through every layer.

        ``cache`` is either None or a sequence with one entry per layer, handed
        to the matching layer in order.
        """
        per_layer = [None] * len(self.layers) if cache is None else cache
        h = self.embeddings(x)
        for block, layer_cache in zip(self.layers, per_layer):
            h = block(h, layer_cache)
        return self.norm_f(h)
|
2024-10-02 18:48:15 +08:00
|
|
|
|
2024-11-06 23:35:46 +08:00
|
|
|
|
2024-10-02 18:48:15 +08:00
|
|
|
class Model(nn.Module):
    """Mamba2 language model: backbone plus a tied or separate LM head."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba2(args)

        # A separate output projection exists only when embeddings are untied.
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        """Return vocabulary logits for the token ids in ``inputs``."""
        hidden = self.backbone(inputs, cache)
        project = (
            self.backbone.embeddings.as_linear
            if self.args.tie_word_embeddings
            else self.lm_head
        )
        return project(hidden)

    def sanitize(self, weights):
        """Rewrite conv1d kernels in place whose last axis is not 1.

        Such kernels get axes 1 and 2 swapped via ``moveaxis(2, 1)``.
        The same dict object is returned.
        """
        for name in list(weights):
            value = weights[name]
            if "conv1d.weight" in name and value.shape[-1] != 1:
                weights[name] = value.moveaxis(2, 1)
        return weights

    def make_cache(self):
        """Create one fresh MambaCache per layer."""
        return [MambaCache() for _ in self.layers]

    @property
    def layers(self):
        return self.backbone.layers
|