Goekdeniz-Guelmez
2024-10-30 21:23:13 +01:00
parent ffc7ab06a0
commit 58b448dc0b
4 changed files with 1007 additions and 1285 deletions


@@ -1,275 +1,7 @@
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
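# A minimal sketch of how the derived fields behave (hypothetical values, not from any
# released checkpoint): intermediate_size is only computed because it is not a declared
# field, head_dim keeps the value passed in (the dataclass attribute always exists, so the
# hasattr guard is a no-op for it), and time_step_rank == "auto" resolves to
# ceil(hidden_size / 16).
#
#     args = ModelArgs(
#         num_heads=24, head_dim=64, vocab_size=32000, hidden_size=768,
#         state_size=128, num_hidden_layers=24, layer_norm_epsilon=1e-5,
#         expand=2, conv_kernel=4, n_groups=1, use_bias=False, use_conv_bias=True,
#         initializer_range=0.1, residual_in_fp32=True, time_step_min=0.001,
#         time_step_max=0.1, time_step_floor=1e-4, rescale_prenorm_residual=False,
#         use_cache=True, rms_norm=True, chunk_size=256, tie_word_embeddings=False,
#     )
#     assert args.intermediate_size == 1536            # 2 * 768
#     assert args.time_step_rank == math.ceil(768 / 16)  # 48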
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones((hidden_size,))
self.variance_epsilon = eps
def __call__(self, hidden_states, gate=None):
if gate is not None:
hidden_states = hidden_states * nn.silu(gate)
variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
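# A small sketch of the gated norm (toy shapes): with gate=None this reduces to plain RMS
# normalisation scaled by `weight`; with a gate it first multiplies the hidden states by
# silu(gate) and then normalises.
#
#     norm = MambaRMSNormGated(hidden_size=4, eps=1e-6)
#     h = mx.array([[1.0, 2.0, 3.0, 4.0]])
#     g = mx.array([[0.5, 0.5, 0.5, 0.5]])
#     out_plain = norm(h)        # h / sqrt(mean(h**2) + eps) * weight
#     out_gated = norm(h, g)     # (h * silu(g)) normalised the same way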
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
# Ensure in_channels and out_channels are the same for depthwise conv
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
# Ensure groups is equal to in_channels for depthwise conv
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
# Initialize weight with shape (out_channels, kernel_size, 1)
self.weight = mx.random.normal((out_channels, kernel_size, 1))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x, cache=None):
B, L, C = x.shape
_, K, _ = self.weight.shape
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=self.groups)
if self.bias is not None:
y = y + self.bias
return y, x[:, -K + 1 :, :]
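# A usage sketch (toy sizes) of how the cache works: the layer returns both the output and
# the last K-1 input frames, so feeding tokens one at a time with the returned cache gives
# the same result as one full-sequence call with causal zero padding.
#
#     conv = DepthWiseConv1d(in_channels=8, out_channels=8, kernel_size=4)
#     x = mx.random.normal((2, 6, 8))              # (B, L, C)
#     y_full, _ = conv(x)                          # zero left-padding path
#     cache = mx.zeros((2, 3, 8))                  # K - 1 zero frames
#     ys = []
#     for t in range(6):
#         y_t, cache = conv(x[:, t:t + 1, :], cache)
#         ys.append(y_t)
#     y_step = mx.concatenate(ys, axis=1)          # should match y_full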
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.intermediate_size = args.intermediate_size
self.time_step_rank = args.time_step_rank
self.conv_kernel_size = args.conv_kernel
self.hidden_size = args.hidden_size
self.state_size = args.state_size
self.num_heads = args.num_heads
self.head_dim = args.hidden_size // args.num_heads
self.n_groups = args.n_groups
# projection_size = 2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads
projection_size = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(
args.hidden_size,
projection_size,
bias=args.use_bias
)
# self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
self.conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=args.conv_kernel,
bias=args.use_conv_bias,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
self.A_log = mx.zeros(args.num_heads)
self.D = mx.ones((args.num_heads,))
self.dt_bias = mx.zeros(args.num_heads)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
def ssm_step(self, x, state, dt):
A = -mx.exp(self.A_log)
D = self.D
dt = nn.softplus(dt + self.dt_bias)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
batch_size = B.shape[0]
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
dt = dt.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
if state is None:
new_state = dt * B
else:
new_state = dt * (B + state * mx.exp(dt * A))
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
return y, new_state
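# Restating the update above as the code computes it, per head h and state channel n:
#
#     dt_h = softplus(dt_h + dt_bias_h),      A_h = -exp(A_log_h)
#     h_t  = dt_h * (B_t + h_{t-1} * exp(dt_h * A_h))
#     y_t  = sum_n (h_t * C_t)[n] + D_h * x_t[h]
#
# Since A_h <= 0, exp(dt_h * A_h) lies in (0, 1] and acts as a learned, input-dependent
# decay on the previous state.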
def __call__(self, x, cache):
B, T, D = x.shape
if cache is None:
cache = [None, None]
outputs = []
for t in range(T):
xt = x[:, t, :]
zxbcdt = self.in_proj(xt)
z, xBC, dt = mx.split(
zxbcdt,
# indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
indices_or_sections=[
self.intermediate_size,
self.intermediate_size + 2 * self.state_size,
self.num_heads
],
axis=-1
)
# Use the new DepthWiseConv1d with caching
conv_out, cache[0] = self.conv1d(mx.expand_dims(z, 1), cache[0])
z = conv_out.squeeze(1)
z = nn.silu(z)
y_t, cache[1] = self.ssm_step(z, cache[1], dt)
xBC = nn.silu(xBC)
# Element-wise multiplication
output_t = y_t[:, :, None] * xBC[:, None, :]
output_t = self.norm(output_t)
output_t = output_t.sum(axis=1)
output_t = self.out_proj(output_t)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
return output
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
# self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
else:
logits = self.lm_head(x)
return logits
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
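# A hedged usage sketch (token ids are illustrative; this mirrors the mlx-lm pattern of
# passing a per-layer cache list):
#
#     model = Model(args)
#     cache = model.make_cache()                  # one MambaCache per layer
#     prompt = mx.array([[1, 2, 3]])              # (B, T)
#     logits = model(prompt, cache=cache)         # prefill, updates the cache in place
#     next_token = mx.argmax(logits[:, -1, :], axis=-1)
#     logits = model(next_token[:, None], cache=cache)  # single-token decode step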
# ------
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
@@ -296,130 +28,79 @@ class ModelArgs(BaseModelArgs):
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
use_cache: bool = True
intermediate_size: int = None
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)  # expand * hidden_size
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
def selective_scan(x, A, B, C, chunk_size):
    """
    Selective scan implementation for training.

    Arguments
        x: (batch, seqlen, n_heads, d_head)
        A: (batch, seqlen, n_heads)
        B: (batch, seqlen, n_heads, d_state)
        C: (batch, seqlen, n_heads, d_state)

    Return
        y: (batch, seqlen, n_heads, d_head)
    """
    assert x.shape[1] % chunk_size == 0

    # Reshape into chunks
    def chunk_reshape(m):
        shape = list(m.shape)
        shape[1:2] = [shape[1] // chunk_size, chunk_size]
        return m.reshape(shape)

    x, A, B, C = map(chunk_reshape, (x, A, B, C))
    A = mx.transpose(A, [0, 3, 1, 2])

    # Compute cumulative sums
    A_cumsum = mx.cumsum(A, axis=-1)

    # Process chunks
    L = mx.exp(selective_cumsum(A))
    Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x)

    decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
    states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x)

    initial_states = mx.zeros_like(states[:, :1])
    states = mx.concatenate([initial_states, states], axis=1)
    decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0, 0), (0, 0), (1, 0)))))
    new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states)
    states = new_states[:, :-1]

    state_decay_out = mx.exp(A_cumsum)
    Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

    Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:]))
    return Y


class MambaRMSNormGated(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps

    def __call__(self, hidden_states, gate=None):
        if gate is not None:
            hidden_states = hidden_states * nn.silu(gate)
        variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


def silu(x):
    return x * mx.sigmoid(x)


def ssd(x, A, B, C, chunk_size):
    # Replace einsum operations with explicit reshape and matrix multiply
    batch, seqlen, nheads, dim = x.shape
    B = mx.expand_dims(B, axis=2)
    C = mx.expand_dims(C, axis=2)

    state = mx.zeros((batch, nheads, dim, B.shape[-1]))
    outputs = []

    for i in range(0, seqlen, chunk_size):
        chunk = slice(i, min(i + chunk_size, seqlen))
        dA = mx.exp(mx.expand_dims(A[chunk], axis=0))

        # Replace einsum with explicit operations
        x_chunk = x[:, chunk]                           # [batch, chunk_size, nheads, dim]
        x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1])   # [batch, nheads, dim, chunk_size]
        B_chunk = B[:, chunk]                           # [batch, chunk_size, state_size]
        dBx = mx.matmul(x_chunk, B_chunk)               # [batch, nheads, dim, state_size]
        state = state * mx.expand_dims(dA, axis=-1) + dBx

        # Replace einsum with explicit operations
        C_chunk = C[:, chunk]                           # [batch, chunk_size, state_size]
        y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1]))  # [batch, nheads, dim, chunk_size]
        y = mx.transpose(y, [0, 3, 1, 2])               # [batch, chunk_size, nheads, dim]
        outputs.append(y)

    return mx.concatenate(outputs, axis=1), state
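# Relation of the chunked scans above to the sequential recurrence: in selective_scan,
# Y_diag applies the causal intra-chunk decay matrix L = exp(selective_cumsum(A)),
# `states` summarises each chunk into a small per-head state, `decay_chunk` carries those
# states across chunk boundaries, and Y_off re-injects them through C, giving
# Y = Y_diag + Y_off (the SSD decomposition). The ssd variant instead loops over chunks
# with a single running `state` and explicit matmuls in place of einsum.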
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
self.weight = mx.random.normal((in_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channel count {C} doesn't match expected {self.in_channels}"
if cache is not None and 'conv_states' in cache:
conv_states = cache['conv_states']
if conv_states is not None:
assert conv_states.shape[0] == B, "Cache batch size mismatch"
assert conv_states.shape[2] == C, "Cache channel count mismatch"
x = mx.concatenate([conv_states, x], axis=1)
# Process each channel independently
outputs = []
for c in range(C):
x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1)
w_c = self.weight[c]
if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0)
elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
# Apply convolution
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=0
)
if self.bias is not None:
y_c = y_c + self.bias[c]
outputs.append(mx.squeeze(y_c, axis=1))
y = mx.stack(outputs, axis=-1)
# Update cache
if cache is not None:
cache['conv_states'] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
def selective_cumsum(x: mx.array) -> mx.array:
"""Stable selective cumulative sum calculation."""
T = x.shape[-1]
x = mx.repeat(x[..., None], T, axis=-1)
mask = mx.tril(mx.ones((T, T)), k=-1)
x = x * mask
x_cumsum = mx.cumsum(x, axis=-2)
mask = mx.tril(mx.ones((T, T)), k=0)
return mx.where(mask, x_cumsum, float('-inf'))
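# selective_cumsum(x) builds the "segment sum" matrix used by selective_scan: entry (i, j)
# holds sum_{k=j+1..i} x_k for j <= i (zero on the diagonal) and -inf above the diagonal,
# so that exp(selective_cumsum(A)) is the lower-triangular decay matrix L with
# L[i, j] = exp(A_{j+1} + ... + A_i) and L[i, i] = 1.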
class Mamba2Block(nn.Module):
@@ -427,165 +108,250 @@ class Mamba2Block(nn.Module):
super().__init__()
self.args = args
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)

# Project input to get various components [z, x, B, C, dt]
projection_size = (2 * args.intermediate_size + 2 * args.n_groups * args.state_size + args.num_heads)
self.in_proj = nn.Linear(
    args.hidden_size,
    projection_size,
    bias=args.use_bias
)

conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
    in_channels=conv_dim,
    out_channels=conv_dim,
    kernel_size=args.conv_kernel,
    groups=conv_dim,
    bias=args.use_conv_bias,
    padding=args.conv_kernel - 1
)

# Convolution layer
conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
self.conv1d = nn.Conv1d(
    in_channels=conv_dim,
    out_channels=conv_dim,
    kernel_size=args.conv_kernel,
    groups=conv_dim,
    padding=args.conv_kernel - 1,
    bias=args.use_conv_bias
)

self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)

# SSM parameters
self.dt_bias = mx.zeros(args.num_heads)
self.A_log = mx.zeros(args.num_heads)
self.D = mx.ones(args.num_heads)

# Output projections
self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None) -> mx.array:
# return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache)
def __call__(self, x: mx.array, cache=None):
if cache is not None:
return self.step(x, cache)
# Regular forward pass code remains the same...
d_model = self.args.intermediate_size
d_state = self.args.state_size
n_heads = self.args.num_heads
# def forward_training(self, u: mx.array) -> mx.array:
# # Reset cache during training
# self.cache = None
A = -mx.exp(self.A_log)
zxbcdt = self.in_proj(x)
splits = [d_model, d_model + 2 * d_state, n_heads]
z = zxbcdt[:, :, :splits[0]]
xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
dt = zxbcdt[:, :, -splits[2]:]
# # Input projection and splitting
# zxbcdt = self.in_proj(u)
# z, xBC, dt = mx.split(
# zxbcdt,
# [
# self.args.hidden_size,
# self.args.hidden_size + 2 * self.args.state_size
# ],
# axis=-1
# )
# # Time step processing
# dt = mx.clip(
# nn.softplus(dt + self.dt_bias),
# self.args.time_step_min,
# self.args.time_step_max
# )
# # Convolution processing
# xBC_t = mx.transpose(xBC, [0, 2, 1])
# conv_out = self.conv1d(xBC_t)
# xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]]
# xBC = mx.sigmoid(xBC) * xBC # SiLU
# # Split states
# x, B, C = mx.split(
# xBC,
# [self.args.hidden_size, self.args.state_size],
# axis=-1
# )
# # Reshape for selective scan
# x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim))
# A = -mx.exp(self.A_log)
# # Apply selective scan
# y = selective_scan(
# x * dt[..., None],
# A * dt,
# B[..., None, :],
# C[..., None, :],
# self.args.chunk_size
# )
# # Output processing
# y = y + x * self.D[None, None, :, None]
# y = y.reshape((-1, y.shape[1], self.args.hidden_size))
# y = self.norm(y, z)
# y = self.out_proj(y)
# return y
# def forward_inference(self, u: mx.array, cache=None) -> mx.array:
# """
# u: (B, 1, D)
# cache: (h_cache, conv_cache)
# """
# """Single token processing during inference."""
# assert u.shape[1] == 1, "Inference mode expects single token"
# batch_size = u.shape[0]
# # Use provided cache or create new one
# self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None)
# # Project input
# zxbcdt = self.in_proj(u.squeeze(1)) # (B, 2D)
# d_mlp = (zxbcdt.shape[-1] - 2 * self.args.hidden_size - 2 * self.args.n_groups * self.args.state_size - self.args.num_heads) // 2
# # (1, 768) (1, 0) (1, 0) (1, 256) (1, 0) (1, 3328)
# y0, z0, x0, z, xBC, dt = mx.split(
# zxbcdt,
# [
# d_mlp,
# d_mlp,
# self.args.hidden_size,
# self.args.hidden_size + 2 * self.args.n_groups * self.args.state_size,
# self.args.num_heads
# ],
# axis=-1
# )
# # Update convolution state and apply
# conv_state = self.cache.update_conv_state(xBC)
# xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1) # (B, D) (4, 1792)
# if self.args.use_conv_bias:
# xBC = xBC + self.conv1d.bias
# xBC = mx.sigmoid(xBC) * xBC # SiLU (4, 1792)
# # Split states and ensure proper shapes
# a0, x, B, C = mx.split(
# xBC, # (4, 1792)
# [
# self.args.hidden_size,
# self.args.n_groups * self.args.state_size,
# self.args.n_groups * self.args.state_size
# ],
# axis=-1
# )
# # SSM step with explicit shapes
# A = -mx.exp(self.A_log) # (num_heads) (24,)
# print(A.shape) # (24,)
# print(dt.shape) # (1, 3328)
# dA = mx.exp(dt * A[None, :]) # Shape: (batch_size, num_heads) <------- here is the error
# # Reshape x considering intermediate size
# # x shape should be (batch_size * num_heads, head_dim)
# x = mx.reshape(x, (batch_size, self.args.num_heads, -1))
# assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}"
# B = mx.reshape(B, (batch_size, -1)) # Should be (batch_size, state_size)
# C = mx.reshape(C, (batch_size, -1)) # Should be (batch_size, state_size)
# # Compute dBx with explicit shapes
# dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x)
# ssm_state = self.cache.update_ssm_state(dA, dBx)
# y = mx.einsum('bhds,bs->bhd', ssm_state, C)
# y = y + x * self.D[None, :, None]
# y = mx.reshape(y, (batch_size, self.args.hidden_size))
# # Output processing
# y = self.norm(y, z)
# if d_mlp > 0:
# y = mx.cat([nn.silu(z0) * x0, y], axis=-1)
# y = self.out_proj(y)
# return mx.expand_dims(y, 1)
assert u.shape[1] == 1, "Inference mode expects single token"
batch_size = u.shape[0]
# Use provided cache or create new one
self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None)
# Project input
zxbcdt = self.in_proj(u.squeeze(1)) # (B, projection_size)
# Calculate splits based on model dimensions
d_mlp = self.args.intermediate_size
d_state = self.args.state_size * self.args.n_groups
# Split the projection into its components
splits = [
d_mlp, # y0
d_mlp, # z0
self.args.hidden_size, # x0
self.args.hidden_size, # z
d_state * 2, # xBC (includes both B and C)
self.args.num_heads # dt
]
y0, z0, x0, z, xBC, dt = mx.split(zxbcdt, splits[:-1], axis=-1)
# Update convolution state and apply
conv_state = self.cache.update_conv_state(xBC)
xBC = mx.sum(conv_state[:, :, -1] * mx.transpose(self.conv1d.weight, [1, 0, 2]), axis=-1)
if self.args.use_conv_bias:
xBC = xBC + self.conv1d.bias
xBC = mx.sigmoid(xBC) * xBC # SiLU
# Split states and reshape
x, BC = mx.split(xBC, [self.args.intermediate_size], axis=-1)
B, C = mx.split(BC, [d_state], axis=-1)
# Reshape for SSM computation
x = mx.reshape(x, (batch_size, self.args.num_heads, -1)) # (B, H, head_dim)
B = mx.reshape(B, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head)
C = mx.reshape(C, (batch_size, self.args.num_heads, -1)) # (B, H, state_per_head)
# Process dt to match expected shape
dt = mx.reshape(dt, (batch_size, self.args.num_heads)) # (B, H)
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
xBC = silu(self.conv1d(xBC))
x = xBC[:, :, :d_model]
B = xBC[:, :, d_model:d_model + d_state]
C = xBC[:, :, -d_state:]
b, l, hp = x.shape
h = self.args.num_heads
p = hp // h
x = mx.reshape(x, (b, l, h, p))
y, ssm_state = ssd(x * mx.expand_dims(dt, -1), A * dt, B, C, self.args.chunk_size)
y = y + x * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (b, l, h * p))
y = self.norm(y + z)
# SSM step
A = -mx.exp(self.A_log) # (H,)
dA = mx.exp(dt * A[None, :]) # (B, H)
# Compute dBx
dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, x)
# Update SSM state and compute output
ssm_state = self.cache.update_ssm_state(dA, dBx)
y = mx.einsum('bhds,bhs->bhd', ssm_state, C)
y = y + x * self.D[None, :, None]
# Reshape output
y = mx.reshape(y, (batch_size, self.args.hidden_size))
# Final output processing
y = self.norm(y, z)
if d_mlp > 0:
y = mx.concat([nn.silu(z0) * x0, y], axis=-1)
y = self.out_proj(y)
if self.args.residual_in_fp32:
y = y.astype(mx.float32)
return y
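# Per head h, the single-token branch above performs the matrix-state SSM update
# (assuming cache.update_ssm_state stores state * dA + dBx, as the names suggest):
#
#     dA_h = exp(dt_h * A_h)                        # scalar decay per head, A_h = -exp(A_log_h) <= 0
#     S_h  = dA_h * S_h + dt_h * (x_h outer B_h)    # S_h has shape (head_dim, state_size)
#     y_h  = S_h @ C_h + D_h * x_h                  # readout back to head_dim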
def step(self, u: mx.array, cache):
batch_size = u.shape[0]
seq_len = u.shape[1]
outputs = []
# Initialize cache if needed
if cache.conv_states is None:
conv_dim = self.args.intermediate_size + 2 * self.args.state_size
cache.conv_states = mx.zeros((
batch_size,
self.args.conv_kernel - 1,
conv_dim
))
if cache.ssm_state is None:
cache.ssm_state = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
for pos in range(seq_len):
u_t = u[:, pos:pos+1, :]
zxbcdt = self.in_proj(u_t)
d_model = self.args.intermediate_size
d_state = self.args.state_size
n_heads = self.args.num_heads
z = zxbcdt[:, :, :d_model]
xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model]
dt = zxbcdt[:, :, -(n_heads):]
dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
# Create a temporary cache dictionary for the convolution
conv_cache = {'conv_states': cache.conv_states}
xBC = self.conv1d(xBC, cache=conv_cache)
cache.conv_states = conv_cache['conv_states']
xBC = silu(xBC)
x = xBC[:, :, :d_model]
B = xBC[:, :, d_model:d_model + d_state]
C = xBC[:, :, -d_state:]
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, d_state))
B = mx.broadcast_to(B, (batch_size, n_heads, d_state))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, d_state))
C = mx.broadcast_to(C, (batch_size, n_heads, d_state))
C = mx.expand_dims(C, axis=3)
A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
cache.ssm_state = cache.ssm_state * dA + dBx
y = mx.matmul(cache.ssm_state, C)
y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y + z)
y = self.out_proj(y)
if self.args.residual_in_fp32:
y = y.astype(mx.float32)
outputs.append(y)
return mx.concatenate(outputs, axis=1)
return mx.expand_dims(y, 1) # (B, 1, D)
class ResidualBlock(nn.Module):
@@ -594,11 +360,12 @@ class ResidualBlock(nn.Module):
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
return self.mixer(self.norm(x), cache) + x
def __call__(self, x: mx.array, cache=None) -> mx.array:
# x : (B, L, D)
return self.mixer(self.norm(x), cache) + x # (B, L, D)
class Mamba2(nn.Module):
class Mamba2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
@@ -606,12 +373,15 @@ class Mamba2(nn.Module):
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
def __call__(self, x: mx.array, cache=None) -> mx.array:
# x : (B, L)
x = self.embeddings(x)
# x : (B, L, D)
if cache is None:
cache = [None] * len(self.layers)
for layer, c in zip(self.layers, cache):
x = layer(x, c)
for layer, layer_cache in zip(self.layers, cache):
x = layer(x, layer_cache)
return self.norm_f(x)
@@ -619,14 +389,13 @@ class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2Model(args)
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
def __call__(self, inputs: mx.array, cache=None) -> mx.array:
# inputs : (B, L)
B, T = inputs.shape
x = self.backbone(inputs, cache)
@@ -637,24 +406,19 @@ class Model(nn.Module):
logits = self.lm_head(x)
return logits
def make_cache(self):
return [Mamba2Cache() for _ in range(len(self.layers))]
def make_cache(self, batch_size=1):
return [Mamba2Cache(
batch_size=batch_size,
hidden_size=self.args.hidden_size,
state_size=self.args.state_size,
conv_kernel=self.args.conv_kernel,
num_heads=self.args.num_heads,
head_dim=self.args.head_dim
) for _ in range(len(self.backbone.layers))]
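# Usage sketch for the updated signature (illustrative token ids; Mamba2Cache is assumed to
# accept exactly the keyword arguments passed above): one cache entry per residual block.
#
#     cache = model.make_cache(batch_size=2)
#     assert len(cache) == model.args.num_hidden_layers
#     logits = model(mx.array([[7, 8, 9], [7, 8, 9]]), cache=cache)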
def sanitize(self, weights):
    sanitized = {}
    for k, v in weights.items():
        if "conv1d.weight" in k:
            # Ensure weights are in correct shape (channels, 1, kernel_size)
            if v.ndim == 2:
                v = mx.expand_dims(v, axis=1)
            elif v.ndim == 1:
                v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
            sanitized[k] = v
        else:
            sanitized[k] = v
    return sanitized

@property
def layers(self):
    return self.backbone.layers

def sanitize(self, weights):
    for k, v in weights.items():
        if "conv1d.weight" in k and v.ndim == 3:
            weights[k] = v.moveaxis(2, 1)
    return weights
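# Layout note (a sketch of why sanitize touches conv1d.weight): PyTorch stores depthwise
# Conv1d weights as (channels, 1, kernel_size), while the MLX convolution here expects
# (channels, kernel_size, 1), so a 3-D weight only needs a single axis move.
#
#     w_pt = mx.zeros((1792, 1, 4))          # hypothetical checkpoint layout
#     w_mx = w_pt.moveaxis(2, 1)             # -> (1792, 4, 1)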