Goekdeniz-Guelmez 2024-10-24 16:16:42 +02:00
parent a677638c4b
commit 7c8849e795
4 changed files with 757 additions and 274 deletions

View File

@@ -340,21 +340,130 @@ class MambaCache(_BaseCache):
        self.cache = v


class Mamba2Cache:
    batch_size: int
    intermediate_size: int
    state_size: int
    conv_kernel: int
    num_heads: int
    head_dim: int

    def __init__(
        self,
        batch_size: int,
        intermediate_size: int,
        state_size: int,
        conv_kernel: int,
        num_heads: int,
        head_dim: int
    ):
        self.batch_size = batch_size
        self.intermediate_size = intermediate_size
        self.state_size = state_size
        self.conv_kernel = conv_kernel
        self.num_heads = num_heads
        self.head_dim = head_dim

        # Initialize conv state with proper dimensions
        self.conv_dim = self.intermediate_size + 2 * self.state_size
        self.conv_state = mx.zeros((batch_size, self.conv_dim, conv_kernel - 1))

        # Initialize SSM state
        self.ssm_state = mx.zeros((
            batch_size,
            num_heads,
            head_dim,
            state_size
        ))

    def update_conv_state(self, x: mx.array) -> mx.array:
        """
        Update convolution state for incremental inference.

        Args:
            x: Input tensor containing projected values (B, conv_in_dim)

        Returns:
            Combined state tensor of shape (batch_size, conv_dim, kernel_size)
        """
        # Handle input shape
        if x.ndim == 1:
            x = mx.expand_dims(x, 0)  # Add batch dimension if needed

        # Ensure batch size matches
        assert x.shape[0] == self.batch_size, f"Batch size mismatch: {x.shape[0]} vs {self.batch_size}"

        # Reshape x to match conv_dim
        # The input x contains intermediate_size + 2 * state_size dimensions
        x_reshaped = mx.reshape(x, (self.batch_size, -1))
        x_padded = mx.pad(
            x_reshaped,
            [(0, 0), (0, self.conv_dim - x_reshaped.shape[1])],
            mode='constant',
            constant_values=0
        )

        # Expand dims for concatenation
        x_expanded = mx.expand_dims(x_padded, -1)  # Shape: (batch_size, conv_dim, 1)

        # Roll the existing state left by 1
        rolled_state = mx.roll(self.conv_state, shift=-1, axis=-1)

        # Create update mask for the last position
        update_pos = self.conv_kernel - 2
        state_idx = mx.arange(self.conv_kernel - 1)
        update_mask = state_idx == update_pos

        # Broadcast mask to match state dimensions
        update_mask = mx.broadcast_to(
            mx.reshape(update_mask, (1, 1, -1)),
            rolled_state.shape
        )

        # Update state with padded input
        x_broadcast = mx.broadcast_to(x_expanded, (self.batch_size, self.conv_dim, 1))
        self.conv_state = mx.where(
            update_mask,
            x_broadcast,
            rolled_state
        )

        # Return concatenated state for convolution
        return mx.concatenate([self.conv_state, x_expanded], axis=-1)

    def update_ssm_state(self, dA: mx.array, dBx: mx.array) -> mx.array:
        """
        Update SSM state for incremental inference.

        Args:
            dA: State transition tensor of shape (batch_size, num_heads)
            dBx: Input projection tensor of shape (batch_size, num_heads, head_dim, state_size)

        Returns:
            Updated SSM state of shape (batch_size, num_heads, head_dim, state_size)
        """
        # Add necessary dimensions to dA for broadcasting
        # dA shape: (batch_size, num_heads) -> (batch_size, num_heads, 1, 1)
        dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)

        # Ensure dBx has the correct shape
        assert dBx.shape[-1] == self.state_size, f"dBx state dimension mismatch: {dBx.shape[-1]} vs {self.state_size}"
        assert dBx.shape[-2] == self.head_dim, f"dBx head dimension mismatch: {dBx.shape[-2]} vs {self.head_dim}"

        # Update state: state = dA * state + dBx
        self.ssm_state = dA * self.ssm_state + dBx
        return self.ssm_state

    @classmethod
    def get_cache(
        cls,
        args,
        batch_size: int,
        max_seq_length: Optional[int]
    ) -> "Mamba2Cache":
        """Create a new cache instance with the given parameters."""
        return cls(
            batch_size=batch_size,
            intermediate_size=args.intermediate_size,
            state_size=args.state_size,
            conv_kernel=args.conv_kernel,
            num_heads=args.num_heads,
            head_dim=args.head_dim
        )
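For reference, a minimal sketch of how this cache is meant to be driven one token at a time; all sizes below are toy, illustrative values, not values from any real config.

# Illustrative only: drive Mamba2Cache one decoding step with toy dimensions.
import mlx.core as mx

cache = Mamba2Cache(
    batch_size=1,
    intermediate_size=8,   # toy value
    state_size=4,          # toy value
    conv_kernel=4,
    num_heads=2,
    head_dim=4,
)

# Per step: push the projected token into the rolling convolution window ...
x_t = mx.zeros((1, cache.conv_dim))          # (batch, intermediate_size + 2 * state_size)
conv_window = cache.update_conv_state(x_t)   # (batch, conv_dim, conv_kernel)

# ... and advance the SSM recurrence state = dA * state + dBx.
dA = mx.ones((1, 2))                         # (batch, num_heads)
dBx = mx.zeros((1, 2, 4, 4))                 # (batch, num_heads, head_dim, state_size)
state = cache.update_ssm_state(dA, dBx)      # (batch, num_heads, head_dim, state_size)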

View File

@@ -258,3 +258,403 @@ class Model(nn.Module):
    @property
    def layers(self):
        return self.backbone.layers

# ------

import math
from dataclasses import dataclass, field
from typing import Tuple, Union

import mlx.core as mx
import mlx.nn as nn

from .base import BaseModelArgs
from .cache import Mamba2Cache


@dataclass
class ModelArgs(BaseModelArgs):
    num_heads: int
    head_dim: int
    vocab_size: int
    hidden_size: int
    state_size: int
    num_hidden_layers: int
    layer_norm_epsilon: float
    expand: int
    conv_kernel: int
    n_groups: int
    use_bias: bool
    use_conv_bias: bool
    initializer_range: float
    residual_in_fp32: bool
    time_step_min: float
    time_step_max: float
    time_step_floor: float
    rescale_prenorm_residual: bool
    rms_norm: bool
    chunk_size: int
    tie_word_embeddings: bool
    use_cache: bool = True
    time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
    time_step_rank: Union[int, str] = "auto"
    model_type: str = "mamba2"

    def __post_init__(self):
        if not hasattr(self, "intermediate_size"):
            self.intermediate_size = int(self.expand * self.hidden_size)
        if not hasattr(self, "head_dim"):
            self.head_dim = self.hidden_size // self.num_heads
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)


class MambaRMSNormGated(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = mx.ones((hidden_size,))
        self.variance_epsilon = eps

    def __call__(self, hidden_states, gate=None):
        if gate is not None:
            hidden_states = hidden_states * nn.silu(gate)
        variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


def silu(x):
    return x * mx.sigmoid(x)


def ssd(x, A, B, C, chunk_size):
    # Replace einsum operations with explicit reshape and matrix multiply
    batch, seqlen, nheads, dim = x.shape
    B = mx.expand_dims(B, axis=2)
    C = mx.expand_dims(C, axis=2)

    state = mx.zeros((batch, nheads, dim, B.shape[-1]))
    outputs = []

    for i in range(0, seqlen, chunk_size):
        chunk = slice(i, min(i + chunk_size, seqlen))
        dA = mx.exp(mx.expand_dims(A[chunk], axis=0))

        # Replace einsum with explicit operations
        x_chunk = x[:, chunk]  # [batch, chunk_size, nheads, dim]
        x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1])  # [batch, nheads, dim, chunk_size]
        B_chunk = B[:, chunk]  # [batch, chunk_size, state_size]
        dBx = mx.matmul(x_chunk, B_chunk)  # [batch, nheads, dim, state_size]

        state = state * mx.expand_dims(dA, axis=-1) + dBx

        # Replace einsum with explicit operations
        C_chunk = C[:, chunk]  # [batch, chunk_size, state_size]
        y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1]))  # [batch, nheads, dim, chunk_size]
        y = mx.transpose(y, [0, 3, 1, 2])  # [batch, chunk_size, nheads, dim]
        outputs.append(y)

    return mx.concatenate(outputs, axis=1), state

class DepthWiseConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.groups = groups if groups is not None else in_channels

        assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
        assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"

        self.weight = mx.random.normal((in_channels, 1, kernel_size))
        self.bias = mx.zeros((out_channels,)) if bias else None

    def __call__(self, x: mx.array, cache=None) -> mx.array:
        B, L, C = x.shape
        K = self.kernel_size

        assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"

        if cache is not None and 'conv_states' in cache:
            conv_states = cache['conv_states']
            if conv_states is not None:
                assert conv_states.shape[0] == B, "Cache batch size mismatch"
                assert conv_states.shape[2] == C, "Cache channel count mismatch"
                x = mx.concatenate([conv_states, x], axis=1)

        # Process each channel independently
        outputs = []
        for c in range(C):
            x_c = x[:, :, c]
            x_c = mx.expand_dims(x_c, axis=1)

            w_c = self.weight[c]
            if w_c.ndim == 2:
                w_c = mx.expand_dims(w_c, axis=0)
            elif w_c.ndim == 1:
                w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)

            # Apply convolution
            y_c = mx.conv_general(
                x_c,
                w_c,
                stride=1,
                padding=0
            )

            if self.bias is not None:
                y_c = y_c + self.bias[c]

            outputs.append(mx.squeeze(y_c, axis=1))

        y = mx.stack(outputs, axis=-1)

        # Update cache
        if cache is not None:
            cache['conv_states'] = x[:, -K+1:, :] if x.shape[1] >= K else x

        return y

class Mamba2Block(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
        self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)

        conv_dim = args.intermediate_size + 2 * args.state_size
        self.conv1d = DepthWiseConv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            kernel_size=args.conv_kernel,
            groups=conv_dim,
            bias=args.use_conv_bias,
            padding=args.conv_kernel - 1
        )

        self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
        self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
        self.D = mx.random.normal((args.num_heads,)) * args.initializer_range

        self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)

        if args.rescale_prenorm_residual:
            layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
            self.out_proj.weight = self.out_proj.weight * layer_scale

    def __call__(self, x: mx.array, cache=None):
        if cache is not None:
            return self.step(x, cache)

        # Regular forward pass code remains the same...
        d_model = self.args.intermediate_size
        d_state = self.args.state_size
        n_heads = self.args.num_heads

        A = -mx.exp(self.A_log)
        zxbcdt = self.in_proj(x)

        splits = [d_model, d_model + 2 * d_state, n_heads]
        z = zxbcdt[:, :, :splits[0]]
        xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
        dt = zxbcdt[:, :, -splits[2]:]

        dt = mx.clip(
            nn.softplus(dt + self.dt_bias),
            self.args.time_step_min,
            self.args.time_step_max
        )
        dt = mx.maximum(dt, self.args.time_step_floor)

        xBC = silu(self.conv1d(xBC))

        x = xBC[:, :, :d_model]
        B = xBC[:, :, d_model:d_model + d_state]
        C = xBC[:, :, -d_state:]

        b, l, hp = x.shape
        h = self.args.num_heads
        p = hp // h
        x = mx.reshape(x, (b, l, h, p))

        y, ssm_state = ssd(x * mx.expand_dims(dt, -1), A * dt, B, C, self.args.chunk_size)
        y = y + x * mx.expand_dims(self.D, -1)
        y = mx.reshape(y, (b, l, h * p))

        y = self.norm(y + z)
        y = self.out_proj(y)

        if self.args.residual_in_fp32:
            y = y.astype(mx.float32)

        return y

    def step(self, u: mx.array, cache):
        batch_size = u.shape[0]
        seq_len = u.shape[1]
        outputs = []

        # Initialize cache if needed
        if cache.conv_states is None:
            conv_dim = self.args.intermediate_size + 2 * self.args.state_size
            cache.conv_states = mx.zeros((
                batch_size,
                self.args.conv_kernel - 1,
                conv_dim
            ))

        if cache.ssm_state is None:
            cache.ssm_state = mx.zeros((
                batch_size,
                self.args.num_heads,
                self.args.head_dim,
                self.args.state_size
            ))

        for pos in range(seq_len):
            u_t = u[:, pos:pos+1, :]
            zxbcdt = self.in_proj(u_t)

            d_model = self.args.intermediate_size
            d_state = self.args.state_size
            n_heads = self.args.num_heads

            z = zxbcdt[:, :, :d_model]
            xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model]
            dt = zxbcdt[:, :, -(n_heads):]

            dt = mx.reshape(dt, (batch_size, n_heads))
            dt = mx.clip(
                nn.softplus(dt + self.dt_bias),
                self.args.time_step_min,
                self.args.time_step_max
            )
            dt = mx.maximum(dt, self.args.time_step_floor)

            # Create a temporary cache dictionary for the convolution
            conv_cache = {'conv_states': cache.conv_states}
            xBC = self.conv1d(xBC, cache=conv_cache)
            cache.conv_states = conv_cache['conv_states']

            xBC = silu(xBC)

            x = xBC[:, :, :d_model]
            B = xBC[:, :, d_model:d_model + d_state]
            C = xBC[:, :, -d_state:]

            x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
            x = mx.squeeze(x, axis=1)

            B = mx.reshape(B, (batch_size, 1, d_state))
            B = mx.broadcast_to(B, (batch_size, n_heads, d_state))
            B = mx.expand_dims(B, axis=2)

            C = mx.reshape(C, (batch_size, 1, d_state))
            C = mx.broadcast_to(C, (batch_size, n_heads, d_state))
            C = mx.expand_dims(C, axis=3)

            A = -mx.exp(self.A_log)
            dA = mx.exp(dt * mx.expand_dims(A, 0))
            dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)

            x = mx.expand_dims(x, axis=3)
            dBx = mx.matmul(x, B)

            cache.ssm_state = cache.ssm_state * dA + dBx

            y = mx.matmul(cache.ssm_state, C)
            y = mx.squeeze(y, axis=-1)

            y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
            y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))

            y = self.norm(y + z)
            y = self.out_proj(y)

            if self.args.residual_in_fp32:
                y = y.astype(mx.float32)

            outputs.append(y)

        return mx.concatenate(outputs, axis=1)

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.mixer = Mamba2Block(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
        return self.mixer(self.norm(x), cache) + x


class Mamba2(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache):
        x = self.embeddings(x)
        if cache is None:
            cache = [None] * len(self.layers)
        for layer, c in zip(self.layers, cache):
            x = layer(x, c)
        return self.norm_f(x)


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba2(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        B, T = inputs.shape
        x = self.backbone(inputs, cache)
        if self.args.tie_word_embeddings:
            logits = self.backbone.embeddings.as_linear(x)
        else:
            logits = self.lm_head(x)
        return logits

    def make_cache(self):
        return [Mamba2Cache() for _ in range(len(self.layers))]

    def sanitize(self, weights):
        sanitized = {}
        for k, v in weights.items():
            if "conv1d.weight" in k:
                # Ensure weights are in correct shape (channels, 1, kernel_size)
                if v.ndim == 2:
                    v = mx.expand_dims(v, axis=1)
                elif v.ndim == 1:
                    v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
                sanitized[k] = v
            else:
                sanitized[k] = v
        return sanitized

    @property
    def layers(self):
        return self.backbone.layers
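Both `ssd` above and the per-token `step` path realize the same SSM recurrence, just at different granularity. As a reference, a naive sequential sketch of that recurrence with toy, illustrative shapes (not a call into the code above):

# Naive reference for the recurrence: state_t = exp(A*dt_t)*state_{t-1} + dt_t * outer(x_t, B_t); y_t = state_t @ C_t + D*x_t
import mlx.core as mx

batch, seqlen, nheads, head_dim, d_state = 1, 5, 2, 3, 4
x = mx.random.normal((batch, seqlen, nheads, head_dim))
dt = mx.random.uniform(shape=(batch, seqlen, nheads))
A = -mx.random.uniform(shape=(nheads,))
B = mx.random.normal((batch, seqlen, d_state))
C = mx.random.normal((batch, seqlen, d_state))
D = mx.random.normal((nheads,))

state = mx.zeros((batch, nheads, head_dim, d_state))
ys = []
for t in range(seqlen):
    dA = mx.exp(dt[:, t] * A)                              # (batch, nheads)
    dBx = (dt[:, t, :, None, None] * x[:, t, :, :, None]
           * B[:, t, None, None, :])                       # (batch, nheads, head_dim, d_state)
    state = dA[..., None, None] * state + dBx
    y = mx.sum(state * C[:, t, None, None, :], axis=-1)    # (batch, nheads, head_dim)
    ys.append(y + D[None, :, None] * x[:, t])
y_seq = mx.stack(ys, axis=1)                               # (batch, seqlen, nheads, head_dim)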

View File

@@ -88,6 +88,32 @@ class Mamba2LMHeadModel(nn.Module):
        )
        self.lm_head.weight = self.backbone.embedding.weight

    @staticmethod
    def from_pretrained(huggingface_model_id: str, device: Device = None):
        from transformers.utils import CONFIG_NAME, WEIGHTS_NAME
        from transformers.utils.hub import cached_file

        config_path = cached_file(huggingface_model_id, CONFIG_NAME)
        assert config_path, "Failed to get huggingface config file"
        state_dict_path = cached_file(huggingface_model_id, WEIGHTS_NAME)
        assert state_dict_path, "Failed to get huggingface state dict file"

        config = json.load(open(config_path))
        args = Mamba2Config(
            d_model=config["d_model"],
            n_layer=config["n_layer"],
            vocab_size=config["vocab_size"],
            pad_vocab_size_multiple=config["pad_vocab_size_multiple"],
        )

        map_location = "cpu" if device is None else device
        state_dict = torch.load(
            state_dict_path, weights_only=True, map_location=map_location, mmap=True
        )
        model = Mamba2LMHeadModel(args, device=device)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    def forward(
        self, input_ids: LongTensor, h: list[InferenceCache] | list[None] | None = None

@@ -193,7 +219,6 @@ class Mamba2(nn.Module):
        self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
        self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
        self.D = nn.Parameter(torch.empty(args.nheads, device=device))
        self.norm = RMSNorm(args.d_inner, device=device)
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
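A minimal usage sketch for the loader added above; the checkpoint id is only an assumed example of a Mamba-2 checkpoint in the original state-spaces format.

# Illustrative only: load a checkpoint with the helper above (model id is an assumption).
model = Mamba2LMHeadModel.from_pretrained("state-spaces/mamba2-130m", device="cpu")
print(sum(p.numel() for p in model.parameters()))  # rough parameter count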

View File

@@ -1,6 +1,7 @@
import math
from dataclasses import dataclass, field
from typing import Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn

@@ -27,10 +28,10 @@ class ModelArgs(BaseModelArgs):
    time_step_max: float
    time_step_floor: float
    rescale_prenorm_residual: bool
    rms_norm: bool
    chunk_size: int
    tie_word_embeddings: bool
    use_cache: bool = True
    time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
    time_step_rank: Union[int, str] = "auto"
    model_type: str = "mamba2"

@@ -43,114 +44,62 @@ class ModelArgs(BaseModelArgs):
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)


def selective_scan(x, A, B, C, chunk_size):
    """
    Selective scan implementation for training.

    Arguments
        x: (batch, seqlen, n_heads, d_head)
        A: (batch, seqlen, n_heads)
        B: (batch, seqlen, n_heads, d_state)
        C: (batch, seqlen, n_heads, d_state)
    Return
        y: (batch, seqlen, n_heads, d_head)
    """
    assert x.shape[1] % chunk_size == 0

    # Reshape into chunks
    def chunk_reshape(m):
        shape = list(m.shape)
        shape[1:2] = [shape[1] // chunk_size, chunk_size]
        return m.reshape(shape)

    x, A, B, C = map(chunk_reshape, (x, A, B, C))
    A = mx.transpose(A, [0, 3, 1, 2])

    # Compute cumulative sums
    A_cumsum = mx.cumsum(A, axis=-1)

    # Process chunks
    L = mx.exp(selective_cumsum(A))
    Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x)

    decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
    states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x)

    initial_states = mx.zeros_like(states[:, :1])
    states = mx.concatenate([initial_states, states], axis=1)
    decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0, 0), (0, 0), (1, 0)))))
    new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states)
    states = new_states[:, :-1]

    state_decay_out = mx.exp(A_cumsum)
    Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

    Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:]))
    return Y


def selective_cumsum(x: mx.array) -> mx.array:
    """Stable selective cumulative sum calculation."""
    T = x.shape[-1]
    x = mx.repeat(x[..., None], T, axis=-1)
    mask = mx.tril(mx.ones((T, T)), k=-1)
    x = x * mask
    x_cumsum = mx.cumsum(x, axis=-2)
    mask = mx.tril(mx.ones((T, T)), k=0)
    return mx.where(mask, x_cumsum, float('-inf'))


class Mamba2Block(nn.Module):
@@ -158,165 +107,172 @@ class Mamba2Block(nn.Module):
        super().__init__()
        self.args = args

        # Internal cache state
        self.conv_state = None
        self.ssm_state = None

        # Project input to get various components
        d_in_proj = (2 * args.intermediate_size + 2 * self.args.n_groups * args.state_size + args.num_heads)
        self.in_proj = nn.Linear(
            args.hidden_size,
            d_in_proj,
            bias=args.use_bias
        )

        # Convolution layer
        conv_dim = args.intermediate_size + 2 * self.args.n_groups * args.state_size
        self.conv1d = nn.Conv1d(
            in_channels=conv_dim,
            out_channels=conv_dim,
            kernel_size=args.conv_kernel,
            groups=conv_dim,
            padding=args.conv_kernel - 1,
            bias=args.use_conv_bias
        )

        # SSM parameters
        dt_init_floor = math.log(args.time_step_floor)
        self.dt_bias = mx.zeros((args.num_heads,)) * args.initializer_range
        self.A_log = mx.zeros((args.num_heads,)) * args.initializer_range
        self.D = mx.zeros((args.num_heads,)) * args.initializer_range

        # Output projections
        self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)

    def __call__(self, x: mx.array, cache=None) -> mx.array:
        return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache)

    def forward_training(self, u: mx.array) -> mx.array:
        # Reset cache during training
        self.cache = None

        # Input projection and splitting
        zxbcdt = self.in_proj(u)
        z, xBC, dt = mx.split(
            zxbcdt,
            [
                self.args.intermediate_size,
                self.args.intermediate_size + 2 * self.args.state_size
            ],
            axis=-1
        )

        # Time step processing
        dt = mx.clip(
            nn.softplus(dt + self.dt_bias),
            self.args.time_step_min,
            self.args.time_step_max
        )

        # Convolution processing
        xBC_t = mx.transpose(xBC, [0, 2, 1])
        conv_out = self.conv1d(xBC_t)
        xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]]
        xBC = mx.sigmoid(xBC) * xBC  # SiLU

        # Split states
        x, B, C = mx.split(
            xBC,
            [self.args.intermediate_size, self.args.state_size],
            axis=-1
        )

        # Reshape for selective scan
        x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim))
        A = -mx.exp(self.A_log)

        # Apply selective scan
        y = selective_scan(
            x * dt[..., None],
            A * dt,
            B[..., None, :],
            C[..., None, :],
            self.args.chunk_size
        )

        # Output processing
        y = y + x * self.D[None, None, :, None]
        y = y.reshape((-1, y.shape[1], self.args.intermediate_size))
        y = self.norm(y, z)
        y = self.out_proj(y)

        return y

    def forward_inference(self, u: mx.array, cache=None) -> mx.array:
        """Single token processing during inference."""
        assert u.shape[1] == 1, "Inference mode expects single token"
        batch_size = u.shape[0]

        # Use provided cache or create new one
        self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None)

        # Project input
        zxbcdt = self.in_proj(mx.squeeze(u, 1))
        parts = mx.split(
            zxbcdt,
            [
                self.args.intermediate_size,
                self.args.intermediate_size + 2 * self.args.state_size
            ],
            axis=-1
        )
        z, xBC = parts[0], parts[1]
        dt = zxbcdt[:, -self.args.num_heads:]  # Extract dt separately

        # Update convolution state and apply
        conv_state = self.cache.update_conv_state(xBC)
        xBC = mx.sum(
            conv_state * mx.transpose(self.conv1d.weight, [1, 0, 2]),
            axis=-1
        )
        if self.args.use_conv_bias:
            xBC = xBC + self.conv1d.bias
        xBC = mx.sigmoid(xBC) * xBC  # SiLU

        # Split states and ensure proper shapes
        x_splits = mx.split(
            xBC,
            [self.args.intermediate_size, self.args.state_size],
            axis=-1
        )
        x, B, C = x_splits[0], x_splits[1], x_splits[2]

        # Process time steps - ensure proper broadcasting
        dt = mx.reshape(dt, (batch_size, self.args.num_heads))
        dt = mx.clip(
            nn.softplus(dt + self.dt_bias[None, :]),
            self.args.time_step_min,
            self.args.time_step_max
        )

        # SSM step with explicit shapes
        A = -mx.exp(self.A_log)
        dA = mx.exp(dt * A[None, :])  # Shape: (batch_size, num_heads)

        # Reshape x considering intermediate size
        # x shape should be (batch_size * num_heads, head_dim)
        x = mx.reshape(x, (batch_size, self.args.num_heads, -1))
        assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}"

        # Reshape B and C for ssm computation
        B = mx.reshape(B, (batch_size, -1))  # Should be (batch_size, state_size)
        C = mx.reshape(C, (batch_size, -1))  # Should be (batch_size, state_size)

        # Compute dBx with explicit shapes
        dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x)

        ssm_state = self.cache.update_ssm_state(dA, dBx)
        y = mx.einsum('bhds,bs->bhd', ssm_state, C)
        y = y + x * self.D[None, :, None]
        y = mx.reshape(y, (batch_size, self.args.intermediate_size))

        # Output processing
        y = self.norm(y, z)
        y = self.out_proj(y)

        return mx.expand_dims(y, 1)


class ResidualBlock(nn.Module):
@@ -325,11 +281,11 @@ class ResidualBlock(nn.Module):
        self.mixer = Mamba2Block(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache=None) -> mx.array:
        return self.mixer(self.norm(x), cache) + x


class Mamba2Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

@@ -337,12 +293,12 @@ class Mamba2(nn.Module):
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(self, x: mx.array, cache=None) -> mx.array:
        x = self.embeddings(x)
        if cache is None:
            cache = [None] * len(self.layers)
        for layer, layer_cache in zip(self.layers, cache):
            x = layer(x, layer_cache)
        return self.norm_f(x)

@@ -350,14 +306,12 @@ class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.backbone = Mamba2Model(args)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None) -> mx.array:
        B, T = inputs.shape
        x = self.backbone(inputs, cache)

@@ -368,24 +322,19 @@ class Model(nn.Module):
            logits = self.lm_head(x)
        return logits

    def make_cache(self, batch_size=1):
        return [Mamba2Cache(
            batch_size=batch_size,
            intermediate_size=self.args.intermediate_size,
            state_size=self.args.state_size,
            conv_kernel=self.args.conv_kernel,
            num_heads=self.args.num_heads,
            head_dim=self.args.head_dim
        ) for _ in range(len(self.backbone.layers))]

    def sanitize(self, weights):
        for k, v in weights.items():
            if "conv1d.weight" in k and v.ndim == 3:
                weights[k] = v.moveaxis(2, 1)
        return weights
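A rough sketch of how the cached single-token path above is intended to be driven end to end; `args` and the greedy sampling rule are placeholders, all values illustrative.

# Illustrative decode loop, assuming `args: ModelArgs` is already populated.
import mlx.core as mx

model = Model(args)
cache = model.make_cache(batch_size=1)

tokens = mx.array([[1]])            # a single prompt/BOS token id (placeholder)
generated = []
for _ in range(16):
    logits = model(tokens, cache)   # (1, 1, vocab_size); per-layer caches are updated in place
    tokens = mx.argmax(logits[:, -1, :], axis=-1, keepdims=True)  # greedy pick
    generated.append(tokens.item())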