# mlx-examples/llms/mlx_lm/models/mamba2 copy.py
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
chunk_size: int
tie_word_embeddings: bool
time_step_limit: Tuple[float, float]
time_step_rank: Union[int, str]
time_step_min: float
time_step_max: float
time_step_floor: float
norm_before_gate: bool = True
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
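# Illustrative sketch (not called anywhere) of how __post_init__ fills in the derived
# fields. Every hyperparameter value below is an arbitrary toy choice, not a released
# configuration; the helper name is ours.
def _demo_model_args():
    args = ModelArgs(
        model_type="mamba2",
        num_heads=4,
        head_dim=16,
        vocab_size=100,
        hidden_size=64,
        state_size=16,
        num_hidden_layers=2,
        layer_norm_epsilon=1e-5,
        expand=2,
        conv_kernel=4,
        n_groups=1,
        use_bias=False,
        use_conv_bias=True,
        initializer_range=0.02,
        residual_in_fp32=False,
        chunk_size=4,
        tie_word_embeddings=True,
        time_step_limit=(0.0, float("inf")),
        time_step_rank="auto",
        time_step_min=0.001,
        time_step_max=0.1,
        time_step_floor=1e-4,
    )
    # Derived in __post_init__: intermediate_size = expand * hidden_size, and
    # time_step_rank = ceil(hidden_size / 16) when set to "auto"
    assert args.intermediate_size == args.expand * args.hidden_size
    assert args.time_step_rank == math.ceil(args.hidden_size / 16)
    return args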
def segsum(x):
"""Stable segment sum calculation.
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
"""
T = x.shape[-1]
x = mx.expand_dims(x, -1)
x = mx.repeat(x, T, axis=-1)
mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=-1)
x = mx.where(mask, x, 0)
x_segsum = mx.cumsum(x, axis=-2)
mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=0)
x_segsum = mx.where(mask, x_segsum, -mx.inf)
return x_segsum
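# A quick, illustrative sanity check (not called anywhere): entry [i, j] of segsum(x)
# should equal sum(x[j+1 : i+1]) on or below the diagonal and -inf above it. The toy
# size and tolerance are arbitrary choices.
def _demo_segsum():
    x = mx.random.normal((4,))
    s = segsum(x)
    for i in range(4):
        for j in range(4):
            if j > i:
                assert s[i, j].item() == -float("inf")
            else:
                expected = mx.sum(x[j + 1 : i + 1]).item() if j < i else 0.0
                assert abs(s[i, j].item() - expected) < 1e-5
    return s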
def ssd(x, A, B, C, chunk_size, initial_states=None):
"""Structured State Space Duality (SSD) - the core of Mamba-2
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
Return
y: (batch, seqlen, n_heads, d_head)
final_state: final state for inference
"""
assert x.shape[1] % chunk_size == 0
# Rearrange into chunks
def rearrange_to_chunks(m):
shape = list(m.shape)
shape[1:2] = [shape[1] // chunk_size, chunk_size]
return m.reshape(shape)
x_chunked = rearrange_to_chunks(x)
A_chunked = rearrange_to_chunks(A)
B_chunked = rearrange_to_chunks(B)
C_chunked = rearrange_to_chunks(C)
# Transpose A for easier cumsum
A_chunked = mx.transpose(A_chunked, (0, 3, 1, 2)) # b c l h -> b h c l
A_cumsum = mx.cumsum(A_chunked, axis=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = mx.exp(segsum(A_chunked))
Y_diag = mx.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C_chunked, B_chunked, L, x_chunked)
# 2. Compute the state for each intra-chunk
decay_states = mx.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = mx.einsum("bclhn,bhcl,bclhp->bchpn", B_chunked, decay_states, x_chunked)
# 3. Compute the inter-chunk SSM recurrence
if initial_states is None:
initial_states = mx.zeros_like(states[:, :1])
states = mx.concatenate([initial_states, states], axis=1)
A_cumsum_last = A_cumsum[:, :, :, -1]
A_cumsum_padded = mx.pad(A_cumsum_last, [(0, 0), (0, 0), (1, 0)])
decay_chunk = mx.exp(segsum(A_cumsum_padded))
new_states = mx.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
state_decay_out = mx.exp(A_cumsum)
Y_off = mx.einsum("bclhn,bchpn,bhcl->bclhp", C_chunked, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms
Y_combined = Y_diag + Y_off
# Reshape back to original sequence shape
batch, chunks, chunk_len, heads, head_dim = Y_combined.shape
Y = Y_combined.reshape(batch, chunks * chunk_len, heads, head_dim)
return Y, final_state
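# Illustrative sketch (not called anywhere): run the chunked SSD scan on random toy
# tensors and check the documented output shapes. All sizes below are arbitrary;
# seqlen must be a multiple of chunk_size for the assert at the top of ssd to pass.
def _demo_ssd():
    batch, seqlen, n_heads, d_head, d_state, chunk = 2, 8, 4, 16, 32, 4
    x = mx.random.normal((batch, seqlen, n_heads, d_head))
    A = -mx.abs(mx.random.normal((batch, seqlen, n_heads)))  # decay terms should be negative
    B = mx.random.normal((batch, seqlen, n_heads, d_state))
    C = mx.random.normal((batch, seqlen, n_heads, d_state))
    y, final_state = ssd(x, A, B, C, chunk)
    assert y.shape == (batch, seqlen, n_heads, d_head)
    assert final_state.shape == (batch, n_heads, d_head, d_state)
    return y, final_state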
def silu(x):
"""Applies the Sigmoid Linear Unit (SiLU), element-wise."""
return x * mx.sigmoid(x)
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.d_model = args.hidden_size
self.d_state = args.state_size
self.d_conv = args.conv_kernel
self.expand = args.expand
self.d_inner = int(self.expand * self.d_model)
self.n_groups = args.n_groups
self.n_heads = args.num_heads
self.d_head = self.d_inner // self.n_heads
self.chunk_size = args.chunk_size
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
# Use standard Conv1d with groups for depthwise convolution
conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=self.d_conv,
groups=conv_dim, # Makes it depthwise
padding=self.d_conv-1,
bias=args.use_conv_bias
)
self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
def __call__(self, u, cache=None):
"""
Arguments
u: (batch, seqlen, d_model) input
            cache: optional MambaCache holding [conv_state, ssm_state] for inference
        Return
            y: (batch, seqlen, d_model) output; when a cache is given it is updated in place
"""
if cache is not None:
return self.step(u, cache)
# Initialize cache if needed
if cache is None:
cache = [None, None] # Initialize with None values
# Compute projections
zxbcdt = self.in_proj(u)
# Split projections
d_inner = self.d_inner
d_state = self.n_groups * self.d_state
        z, xBC, dt = mx.split(
            zxbcdt,
            [d_inner, 2 * d_inner + 2 * d_state],  # z: d_inner, xBC: d_inner + 2*d_state, dt: n_heads
            axis=-1
        )
# Process dt with softplus
        dt = nn.softplus(dt + self.dt_bias)  # (batch, seqlen, n_heads)
        # Apply the causal depthwise convolution; MLX Conv1d is channels-last,
        # so it consumes (batch, seqlen, channels) directly with no transpose
        xBC_conv = self.conv1d(xBC)
        xBC = silu(xBC_conv[:, :u.shape[1], :])  # keep only the first seqlen (causal) positions
# Split xBC into x, B, C
x, B, C = mx.split(
xBC,
[d_inner, d_inner + d_state],
axis=-1
)
# Reshape x for heads
batch, seqlen = x.shape[0], x.shape[1]
x_reshaped = x.reshape(batch, seqlen, self.n_heads, self.d_head)
# Reshape B and C for SSM
B = B.reshape(batch, seqlen, 1, d_state)
C = C.reshape(batch, seqlen, 1, d_state)
# Apply SSM with SSD algorithm
A = -mx.exp(self.A_log) # (n_heads,)
A_dt = A * dt # (batch, seqlen, n_heads)
y, ssm_state = ssd(
x_reshaped * mx.expand_dims(dt, -1), # Scale x by dt
A_dt,
B,
C,
self.chunk_size
)
# Apply D and reshape
y = y + x_reshaped * mx.reshape(self.D, (1, 1, self.n_heads, 1))
y = y.reshape(batch, seqlen, d_inner)
        # Apply norm and gating: nn.RMSNorm takes a single input, so gate explicitly
        # with sigmoid(z), matching the step() path
        y = self.norm(y) * mx.sigmoid(z)
# Final projection
y = self.out_proj(y)
# Create cache for inference
if seqlen == 1 and cache is not None:
conv_state = mx.zeros((batch, d_inner + 2 * d_state, self.d_conv))
            conv_state = mx.slice_update(
                conv_state,
                xBC.reshape(batch, -1, 1),
                mx.array([0, 0, self.d_conv - 1]),
                axes=(0, 1, 2)
            )
cache[0] = conv_state
cache[1] = ssm_state
        return y
def step(self, u, cache):
"""Take an inference step for the current input and cache
Arguments
u: (batch, seqlen, d_model) - can be multiple tokens
cache: tuple of (conv_state, ssm_state)
        Return
            y: (batch, seqlen, d_model); the cache entries (conv_state, ssm_state) are updated in place
"""
batch, seqlen = u.shape[0], u.shape[1]
# Initialize cache if it's None
if cache[0] is None or cache[1] is None:
d_state = self.n_groups * self.d_state
conv_dim = self.d_inner + 2 * d_state
conv_state = mx.zeros((batch, conv_dim, self.d_conv))
            # Per-head state size: d_state is split evenly across the heads
state_per_head = d_state // self.n_heads
ssm_state = mx.zeros((batch, self.n_heads, self.d_head, state_per_head))
else:
conv_state, ssm_state = cache[0], cache[1]
# Project input
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
# Split projections
d_inner = self.d_inner
d_state = self.n_groups * self.d_state
        z, xBC, dt = mx.split(
            zxbcdt,
            [d_inner, 2 * d_inner + 2 * d_state],  # z: d_inner, xBC: d_inner + 2*d_state, dt: n_heads
            axis=-1
        )
# Process each token through the convolution sequentially
outputs = []
for i in range(seqlen):
# Get current token's input
xBC_i = xBC[:, i] # (batch, d_inner + 2*d_state)
            dt_i = dt[:, i]  # (batch, n_heads) after the corrected in_proj split
            # Process dt with softplus
            dt_heads = nn.softplus(dt_i + self.dt_bias)  # (batch, n_heads)
# Update convolution state
conv_state = mx.roll(conv_state, shift=-1, axis=-1)
            # Write the new token into the last column of the rolling conv buffer
            xBC_reshaped = xBC_i.reshape(batch, -1, 1)
            start_indices = mx.array([0, 0, self.d_conv - 1])
conv_state = mx.slice_update(
conv_state,
xBC_reshaped,
start_indices,
axes=(0, 1, 2)
)
            # Apply one convolution step: MLX depthwise Conv1d weight has shape
            # (conv_dim, d_conv, 1), so reshaping to (1, conv_dim, d_conv) lines it up
            # with the (batch, conv_dim, d_conv) rolling buffer
            weight = self.conv1d.weight
            bias = self.conv1d.bias if self.args.use_conv_bias else None
            xBC_conv = mx.sum(conv_state * weight.reshape(1, -1, self.d_conv), axis=-1)
if bias is not None:
xBC_conv = xBC_conv + bias
xBC_conv = silu(xBC_conv)
# Split xBC
x_i, B_i, C_i = mx.split(
xBC_conv,
[d_inner, d_inner + d_state],
axis=-1
)
# Apply SSM step
A = -mx.exp(self.A_log) # (n_heads,)
dA = mx.exp(dt_heads * A) # (batch, n_heads)
# Reshape x for heads
x_i = x_i.reshape(batch, self.n_heads, self.d_head)
# Reshape B and C for SSM with correct dimensions
state_per_head = d_state // self.n_heads
B_i_reshaped = B_i.reshape(batch, self.n_heads, state_per_head)
C_i_reshaped = C_i.reshape(batch, self.n_heads, state_per_head)
# Calculate dBx with the correctly shaped B
dBx = mx.einsum("bhn,bhp->bhpn", B_i_reshaped, x_i * mx.expand_dims(dt_heads, -1))
# Update SSM state
ssm_state = ssm_state * mx.reshape(dA, (batch, self.n_heads, 1, 1)) + dBx
# Calculate output with the correctly shaped C
y_i = mx.einsum("bhpn,bhn->bhp", ssm_state, C_i_reshaped)
# Apply D and reshape
y_i = y_i + x_i * mx.reshape(self.D, (1, self.n_heads, 1))
# Reshape y
y_i = y_i.reshape(batch, d_inner)
            # Apply norm, then gate with sigmoid(z) (matches the gating in __call__)
            y_i = self.norm(y_i)
            y_i = y_i * nn.sigmoid(z[:, i])
# Final projection
y_i = self.out_proj(y_i)
outputs.append(y_i)
# Stack outputs along sequence dimension
y = mx.stack(outputs, axis=1) # (batch, seqlen, d_model)
# Update cache
cache[0] = conv_state
cache[1] = ssm_state
return y
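# A minimal reference for the per-token recurrence that Mamba2Block.step implements,
# written out for a single head of a single batch element. This is an illustrative
# sketch only; the helper name and toy sizes are ours, not part of the model.
def _demo_step_recurrence():
    d_head, d_state, T = 4, 8, 5
    A = -mx.exp(mx.random.normal((1,)))        # per-head decay rate (negative)
    D = mx.random.normal((1,))                 # per-head skip connection
    h = mx.zeros((d_head, d_state))            # running SSM state for this head
    ys = []
    for _ in range(T):
        x_t = mx.random.normal((d_head,))
        B_t = mx.random.normal((d_state,))
        C_t = mx.random.normal((d_state,))
        dt_t = nn.softplus(mx.random.normal((1,)))              # positive step size
        dA = mx.exp(dt_t * A)                                   # state decay for this step
        h = h * dA + (dt_t * x_t)[:, None] * B_t[None, :]       # h <- h*exp(dt*A) + (dt*x) B^T
        ys.append(mx.sum(h * C_t[None, :], axis=-1) + D * x_t)  # y_t = h C_t + D*x_t
    return mx.stack(ys)                                         # (T, d_head)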
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
normed = self.norm(x)
output = self.mixer(normed, cache)
return output + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
hidden = x
for layer, c in zip(self.layers, cache):
hidden = layer(hidden, c)
return self.norm_f(hidden)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
hidden = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(hidden)
else:
logits = self.lm_head(hidden)
return logits
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers
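# Illustrative end-to-end sketch (not called anywhere): build a tiny randomly
# initialized model from the toy config sketched earlier, run a cached prefill, then
# decode a single token. The prompt ids and the greedy argmax step are arbitrary choices.
def _demo_generate_one_token():
    args = _demo_model_args()                      # toy hyperparameters defined above
    model = Model(args)
    cache = model.make_cache()                     # one MambaCache per layer
    prompt = mx.array([[1, 2, 3, 4]])              # (batch=1, seqlen=4) token ids
    logits = model(prompt, cache)                  # cached prefill pass
    next_token = mx.argmax(logits[:, -1, :], axis=-1)
    logits = model(next_token[:, None], cache)     # single-token decode step
    return next_token, logits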
########################################################
# A second iteration of the same module follows; at import time the definitions
# below shadow the ones above, so the classes above serve as an earlier draft.
########################################################
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
chunk_size: int
tie_word_embeddings: bool
time_step_limit: Tuple[float, float]
time_step_rank: Union[int, str]
time_step_min: float
time_step_max: float
time_step_floor: float
norm_before_gate: bool = True
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
def segsum(x):
"""Stable segment sum calculation.
`exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.
"""
T = x.shape[-1]
x = mx.expand_dims(x, -1)
x = mx.repeat(x, T, axis=-1)
mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=-1)
x = mx.where(mask, x, 0)
x_segsum = mx.cumsum(x, axis=-2)
mask = mx.tril(mx.ones((T, T), dtype=mx.bool_), k=0)
x_segsum = mx.where(mask, x_segsum, -mx.inf)
return x_segsum
def ssd(x, A, B, C, chunk_size, initial_states=None):
"""Structured State Space Duality (SSD) - the core of Mamba-2
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
Return
y: (batch, seqlen, n_heads, d_head)
final_state: final state for inference
"""
assert x.shape[1] % chunk_size == 0
# Rearrange into chunks
def rearrange_to_chunks(m):
shape = list(m.shape)
shape[1:2] = [shape[1] // chunk_size, chunk_size]
return m.reshape(shape)
x_chunked = rearrange_to_chunks(x)
A_chunked = rearrange_to_chunks(A)
B_chunked = rearrange_to_chunks(B)
C_chunked = rearrange_to_chunks(C)
# Transpose A for easier cumsum
A_chunked = mx.transpose(A_chunked, (0, 3, 1, 2)) # b c l h -> b h c l
A_cumsum = mx.cumsum(A_chunked, axis=-1)
# 1. Compute the output for each intra-chunk (diagonal blocks)
L = mx.exp(segsum(A_chunked))
Y_diag = mx.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C_chunked, B_chunked, L, x_chunked)
# 2. Compute the state for each intra-chunk
decay_states = mx.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
states = mx.einsum("bclhn,bhcl,bclhp->bchpn", B_chunked, decay_states, x_chunked)
# 3. Compute the inter-chunk SSM recurrence
if initial_states is None:
initial_states = mx.zeros_like(states[:, :1])
states = mx.concatenate([initial_states, states], axis=1)
A_cumsum_last = A_cumsum[:, :, :, -1]
A_cumsum_padded = mx.pad(A_cumsum_last, [(0, 0), (0, 0), (1, 0)])
decay_chunk = mx.exp(segsum(A_cumsum_padded))
new_states = mx.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
states, final_state = new_states[:, :-1], new_states[:, -1]
# 4. Compute state -> output conversion per chunk
state_decay_out = mx.exp(A_cumsum)
Y_off = mx.einsum("bclhn,bchpn,bhcl->bclhp", C_chunked, states, state_decay_out)
# Add output of intra-chunk and inter-chunk terms
Y_combined = Y_diag + Y_off
# Reshape back to original sequence shape
batch, chunks, chunk_len, heads, head_dim = Y_combined.shape
Y = Y_combined.reshape(batch, chunks * chunk_len, heads, head_dim)
return Y, final_state
def silu(x):
"""Applies the Sigmoid Linear Unit (SiLU), element-wise."""
return x * mx.sigmoid(x)
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.d_model = args.hidden_size
self.d_state = args.state_size
self.d_conv = args.conv_kernel
self.expand = args.expand
self.d_inner = int(self.expand * self.d_model)
self.n_groups = args.n_groups
self.n_heads = args.num_heads
self.d_head = self.d_inner // self.n_heads
self.chunk_size = args.chunk_size
d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
# Use standard Conv1d with groups for depthwise convolution
conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=self.d_conv,
groups=conv_dim,
padding=self.d_conv-1,
bias=args.use_conv_bias
)
self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
def __call__(self, u, cache=None):
"""
Arguments
u: (batch, seqlen, d_model) input
            cache: optional MambaCache holding [conv_state, ssm_state] for inference
        Return
            y: (batch, seqlen, d_model) output; when a cache is given it is updated in place
"""
if cache is not None:
return self.step(u, cache)
# Initialize cache if needed
if cache is None:
cache = [None, None] # Initialize with None values
# Compute projections
zxbcdt = self.in_proj(u)
# Split projections
d_inner = self.d_inner
d_state = self.n_groups * self.d_state
        z, xBC, dt = mx.split(
            zxbcdt,
            [d_inner, 2 * d_inner + 2 * d_state],  # z: d_inner, xBC: d_inner + 2*d_state, dt: n_heads
            axis=-1
        )
# Process dt with softplus
        dt = nn.softplus(dt + self.dt_bias)  # (batch, seqlen, n_heads)
        # Apply the causal depthwise convolution; MLX Conv1d is channels-last,
        # so it consumes (batch, seqlen, channels) directly with no transpose
        xBC_conv = self.conv1d(xBC)
        xBC = silu(xBC_conv[:, :u.shape[1], :])  # keep only the first seqlen (causal) positions
# Split xBC into x, B, C
x, B, C = mx.split(
xBC,
[d_inner, d_inner + d_state],
axis=-1
)
# Reshape x for heads
batch, seqlen = x.shape[0], x.shape[1]
x_reshaped = x.reshape(batch, seqlen, self.n_heads, self.d_head)
# Reshape B and C for SSM
B = B.reshape(batch, seqlen, 1, d_state)
C = C.reshape(batch, seqlen, 1, d_state)
# Apply SSM with SSD algorithm
A = -mx.exp(self.A_log) # (n_heads,)
A_dt = A * dt # (batch, seqlen, n_heads)
y, ssm_state = ssd(
x_reshaped * mx.expand_dims(dt, -1), # Scale x by dt
A_dt,
B,
C,
self.chunk_size
)
# Apply D and reshape
y = y + x_reshaped * mx.reshape(self.D, (1, 1, self.n_heads, 1))
y = y.reshape(batch, seqlen, d_inner)
        # Apply norm and gating: nn.RMSNorm takes a single input, so gate explicitly
        # with sigmoid(z), matching the step() path
        y = self.norm(y) * mx.sigmoid(z)
# Final projection
y = self.out_proj(y)
# Create cache for inference
if seqlen == 1 and cache is not None:
conv_state = mx.zeros((batch, d_inner + 2 * d_state, self.d_conv))
            conv_state = mx.slice_update(
                conv_state,
                xBC.reshape(batch, -1, 1),
                mx.array([0, 0, self.d_conv - 1]),
                axes=(0, 1, 2)
            )
cache[0] = conv_state
cache[1] = ssm_state
return y
def step(self, u, cache):
"""Take an inference step for the current input and cache
Arguments
u: (batch, seqlen, d_model) - can be multiple tokens
cache: tuple of (conv_state, ssm_state)
        Return
            y: (batch, seqlen, d_model); the cache entries (conv_state, ssm_state) are updated in place
"""
batch, seqlen = u.shape[0], u.shape[1]
# Initialize cache if it's None
if cache[0] is None or cache[1] is None:
d_state = self.n_groups * self.d_state
conv_dim = self.d_inner + 2 * d_state
conv_state = mx.zeros((batch, conv_dim, self.d_conv))
            # Per-head state size: d_state is split evenly across the heads
state_per_head = d_state // self.n_heads
ssm_state = mx.zeros((batch, self.n_heads, self.d_head, state_per_head))
else:
conv_state, ssm_state = cache[0], cache[1]
# Project input
zxbcdt = self.in_proj(u) # (batch, seqlen, d_in_proj)
# Split projections
d_inner = self.d_inner
d_state = self.n_groups * self.d_state
        z, xBC, dt = mx.split(
            zxbcdt,
            [d_inner, 2 * d_inner + 2 * d_state],  # z: d_inner, xBC: d_inner + 2*d_state, dt: n_heads
            axis=-1
        )
        # Process dt with softplus once for all tokens; after the corrected in_proj
        # split, dt is already (batch, seqlen, n_heads)
        dt_heads = nn.softplus(dt + self.dt_bias.reshape(1, 1, -1))
# Pre-compute dA for all tokens
A = -mx.exp(self.A_log) # (n_heads,)
dA = mx.exp(dt_heads * A.reshape(1, 1, -1)) # (batch, seqlen, n_heads)
        # Get convolution weights; MLX depthwise Conv1d weight has shape
        # (out_channels, kernel_size, in_channels // groups) == (conv_dim, d_conv, 1)
        weight = self.conv1d.weight
        bias = self.conv1d.bias if self.args.use_conv_bias else None
# Process each token through the convolution sequentially
outputs = []
for i in range(seqlen):
# Get current token's input
xBC_i = xBC[:, i] # (batch, d_inner + 2*d_state)
# Update convolution state
conv_state = mx.roll(conv_state, shift=-1, axis=-1)
# Update the last column of conv_state
conv_state = mx.slice_update(
conv_state,
xBC_i.reshape(batch, -1, 1),
mx.array([0, 0, self.d_conv - 1]),
axes=(0, 1, 2)
)
            # Apply one convolution step manually: each channel of the depthwise conv
            # is a dot product between its row of the (batch, channels, kernel) buffer
            # and its (kernel,)-sized filter; the (channels, d_conv, 1) weight is
            # flattened to (channels, d_conv) and summed along the kernel dimension
            weight_reshaped = weight.reshape(conv_state.shape[1], self.d_conv)
            xBC_conv = mx.sum(conv_state * weight_reshaped.reshape(1, -1, self.d_conv), axis=-1)
if bias is not None:
xBC_conv = xBC_conv + bias
xBC_conv = silu(xBC_conv)
# Split xBC
x_i, BC_rest = mx.split(xBC_conv, [d_inner], axis=-1)
B_i, C_i = mx.split(BC_rest, [d_state], axis=-1)
# Reshape x for heads
x_i = x_i.reshape(batch, self.n_heads, self.d_head)
# Reshape B and C for SSM
state_per_head = d_state // self.n_heads
B_i_reshaped = B_i.reshape(batch, self.n_heads, state_per_head)
C_i_reshaped = C_i.reshape(batch, self.n_heads, state_per_head)
# Get current token's dt and dA
dt_i = dt_heads[:, i] # (batch, n_heads)
dA_i = dA[:, i] # (batch, n_heads)
# Calculate dBx
dBx = mx.einsum("bhn,bhp->bhpn", B_i_reshaped, x_i * mx.expand_dims(dt_i, -1))
# Update SSM state
ssm_state = ssm_state * mx.reshape(dA_i, (batch, self.n_heads, 1, 1)) + dBx
# Calculate output with the correctly shaped C
y_i = mx.einsum("bhpn,bhn->bhp", ssm_state, C_i_reshaped)
# Apply D and reshape
y_i = y_i + x_i * self.D.reshape(1, self.n_heads, 1)
# Reshape y
y_i = y_i.reshape(batch, d_inner)
# Apply norm and gating
y_i = self.norm(y_i) * nn.sigmoid(z[:, i])
# Final projection
y_i = self.out_proj(y_i)
outputs.append(y_i)
# Stack outputs along sequence dimension
y = mx.stack(outputs, axis=1)
# Update cache
cache[0] = conv_state
cache[1] = ssm_state
return y
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
normed = self.norm(x)
output = self.mixer(normed, cache)
return output + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, x: mx.array, cache):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
hidden = x
for layer, c in zip(self.layers, cache):
hidden = layer(hidden, c)
return self.norm_f(hidden)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
hidden = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(hidden)
else:
logits = self.lm_head(hidden)
return logits
def make_cache(self):
return [MambaCache() for _ in range(len(self.layers))]
@property
def layers(self):
return self.backbone.layers