This commit is contained in:
Goekdeniz-Guelmez
2024-10-30 21:23:13 +01:00
parent ffc7ab06a0
commit 58b448dc0b
4 changed files with 1007 additions and 1285 deletions

View File

@@ -32,259 +32,272 @@ class ModelArgs(BaseModelArgs):
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
intermediate_size: int = None
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
self.intermediate_size = int(self.expand * self.hidden_size)
self.intermediate_size = int(self.expand * self.hidden_size) # E*D = ED
if not hasattr(self, "head_dim"):
self.head_dim = self.hidden_size // self.num_heads
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
def selective_scan(x, A, B, C, chunk_size):
"""
Selective scan implementation for training.
Arguments
x: (batch, seqlen, n_heads, d_head)
A: (batch, seqlen, n_heads)
B: (batch, seqlen, n_heads, d_state)
C: (batch, seqlen, n_heads, d_state)
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
self.weight = mx.ones(hidden_size)
self.variance_epsilon = eps
Return
y: (batch, seqlen, n_heads, d_head)
"""
assert x.shape[1] % chunk_size == 0
# Reshape into chunks
def chunk_reshape(m):
shape = list(m.shape)
shape[1:2] = [shape[1] // chunk_size, chunk_size]
return m.reshape(shape)
x, A, B, C = map(chunk_reshape, (x, A, B, C))
A = mx.transpose(A, [0, 3, 1, 2])
# Compute cumulative sums
A_cumsum = mx.cumsum(A, axis=-1)
# Process chunks
L = mx.exp(selective_cumsum(A))
Y_diag = mx.einsum('bclhn,bcshn,bhcls,bcshp->bclhp', C, B, L, x)
decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
states = mx.einsum('bclhn,bhcl,bclhp->bchpn', B, decay_states, x)
initial_states = mx.zeros_like(states[:, :1])
states = mx.concatenate([initial_states, states], axis=1)
decay_chunk = mx.exp(selective_cumsum(mx.pad(A_cumsum[..., -1], ((0,0), (0,0), (1,0)))))
new_states = mx.einsum('bhzc,bchpn->bzhpn', decay_chunk, states)
states = new_states[:, :-1]
state_decay_out = mx.exp(A_cumsum)
Y_off = mx.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)
Y = (Y_diag + Y_off).reshape((-1, x.shape[1] * chunk_size, *Y_diag.shape[-2:]))
return Y
def forward(self, hidden_states, gate=None):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(mx.float32)
def selective_cumsum(x: mx.array) -> mx.array:
"""Stable selective cumulative sum calculation."""
T = x.shape[-1]
x = mx.repeat(x[..., None], T, axis=-1)
mask = mx.tril(mx.ones((T, T)), k=-1)
x = x * mask
x_cumsum = mx.cumsum(x, axis=-2)
mask = mx.tril(mx.ones((T, T)), k=0)
return mx.where(mask, x_cumsum, float('-inf'))
if gate is not None:
hidden_states = hidden_states * nn.functional.silu(gate.to(mx.float32))
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * math.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
class Mamba2Block(nn.Module):
class Mamba2Mixer(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
# Model dimensions
self.hidden_size = args.hidden_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.ssm_state_size = args.state_size
self.n_groups = args.n_groups
self.intermediate_size = int(args.expand * args.hidden_size)
# Internal cache state
self.conv_state = None
self.ssm_state = None
# Convolution parameters
self.conv_kernel = args.conv_kernel
self.use_conv_bias = args.use_conv_bias
# Project input to get various components
d_in_proj = (2 * args.intermediate_size + 2 * self.args.n_groups * args.state_size + args.num_heads)
# Time step parameters
self.time_step_rank = int(args.time_step_rank)
self.time_step_min = args.time_step_min
self.time_step_max = args.time_step_max
# Processing parameters
self.chunk_size = args.chunk_size
self.layer_norm_epsilon = args.layer_norm_epsilon
# Calculate dimensions
self.conv_dim = (self.intermediate_size +
2 * self.n_groups * self.ssm_state_size)
projection_size = (self.intermediate_size +
self.conv_dim +
self.num_heads)
# Initialize layers
self.in_proj = nn.Linear(
args.hidden_size,
d_in_proj,
self.hidden_size,
projection_size,
bias=args.use_bias
)
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=self.conv_kernel,
groups=self.conv_dim,
padding=self.conv_kernel - 1,
bias=self.use_conv_bias
)
# Initialize parameters
self.dt_bias = mx.ones(self.num_heads)
A = mx.arange(1, self.num_heads + 1)
self.A_log = mx.log(A)
self.D = mx.ones(self.num_heads)
# Output layers
self.norm = MambaRMSNormGated(
self.intermediate_size,
eps=self.layer_norm_epsilon
)
self.out_proj = nn.Linear(
self.intermediate_size,
self.hidden_size,
bias=args.use_bias
)
# Convolution layer
conv_dim = args.intermediate_size + 2 * self.args.n_groups * args.state_size
self.conv1d = nn.Conv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.conv_kernel,
groups=conv_dim,
padding=args.conv_kernel - 1,
bias=args.use_conv_bias
)
# SSM parameters
dt_init_floor = math.log(args.time_step_floor)
self.dt_bias = mx.zeros((args.num_heads,)) * args.initializer_range
self.A_log = mx.zeros((args.num_heads,)) * args.initializer_range
self.D = mx.zeros((args.num_heads,)) * args.initializer_range
# Output projections
self.norm = nn.RMSNorm(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
def __call__(self, x: mx.array, cache=None) -> mx.array:
return self.forward_training(x) if x.shape[1] > 1 else self.forward_inference(x, cache)
def forward_training(self, u: mx.array) -> mx.array:
# Reset cache during training
self.cache = None
def reshape_into_chunks(self, tensor, pad_size, chunk_size):
if pad_size > 0:
pad_shape = list(tensor.shape)
pad_shape[1] = pad_size
padding = mx.zeros(pad_shape, dtype=tensor.dtype)
tensor = mx.concatenate([tensor, padding], axis=1)
# Input projection and splitting
zxbcdt = self.in_proj(u)
z, xBC, dt = mx.split(
zxbcdt,
[
self.args.intermediate_size,
self.args.intermediate_size + 2 * self.args.state_size
],
axis=-1
)
chunk_shape = list(tensor.shape)
chunk_shape[1] = -1
chunk_shape.insert(2, chunk_size)
return tensor.reshape(chunk_shape)
# Time step processing
def segment_sum(self, x):
return mx.cumsum(x, axis=-1)
def process_single_token(self, hidden_states, B, C, dt, cache):
batch_size = hidden_states.shape[0]
# Process convolution state
if cache is not None:
conv_state = cache.conv_states
# Roll the conv state and update the last position
conv_state = mx.roll(conv_state, shift=-1, axis=-1)
# Create new conv state with updated last position
new_conv_state = mx.array(conv_state)
new_conv_state = new_conv_state.at[:, :, -1].add(hidden_states)
conv_state = new_conv_state
# Compute convolution
conv_out = mx.sum(conv_state * self.conv1d.weight[:, 0, :], axis=-1)
if self.use_conv_bias:
conv_out = conv_out + self.conv1d.bias
# Apply SiLU activation
conv_out = mx.sigmoid(conv_out) * conv_out
else:
# Initialize new cache
conv_state = mx.zeros((batch_size, self.conv_dim, self.conv_kernel - 1))
conv_out = self.conv1d(hidden_states)
conv_out = mx.sigmoid(conv_out) * conv_out
# Process SSM
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
self.time_step_min,
self.time_step_max
)
# Convolution processing
xBC_t = mx.transpose(xBC, [0, 2, 1])
conv_out = self.conv1d(xBC_t)
xBC = mx.transpose(conv_out, [0, 2, 1])[:, :u.shape[1]]
xBC = mx.sigmoid(xBC) * xBC # SiLU
# Split states
x, B, C = mx.split(
xBC,
[self.args.intermediate_size, self.args.state_size],
axis=-1
)
# Reshape for selective scan
x = x.reshape((-1, x.shape[1], self.args.num_heads, self.args.head_dim))
A = -mx.exp(self.A_log)
dA = mx.exp(dt * A[None, :])
if cache is not None:
ssm_state = cache.ssm_states
else:
ssm_state = mx.zeros(
(batch_size, self.num_heads, self.head_dim, self.ssm_state_size)
)
# Compute SSM updates
dBx = mx.einsum('bh,bhs,bhd->bhds', dt, B, hidden_states)
next_state = ssm_state * dA[:, :, None, None] + dBx
y = mx.einsum('bhds,bhs->bhd', next_state, C)
# Add skip connection
y = y + hidden_states * self.D[None, :, None]
return y, conv_state, next_state
# Apply selective scan
y = selective_scan(
x * dt[..., None],
A * dt,
B[..., None, :],
C[..., None, :],
self.args.chunk_size
def process_long_sequence(self, hidden_states, B, C, dt, ssm_state):
batch_size, seq_len = hidden_states.shape[:2]
pad_size = self.chunk_size - (seq_len % self.chunk_size)
# Reshape into chunks
x_chunks = self.reshape_into_chunks(hidden_states, pad_size, self.chunk_size)
B_chunks = self.reshape_into_chunks(B, pad_size, self.chunk_size)
C_chunks = self.reshape_into_chunks(C, pad_size, self.chunk_size)
# Process time steps
dt = nn.softplus(dt + self.dt_bias)
dt = mx.clip(dt, self.time_step_min)
# Prepare matrices
A = -mx.exp(self.A_log)
A = A * dt[:, None]
# Process chunks
A_chunks = self.reshape_into_chunks(
mx.broadcast_to(A, (batch_size, seq_len + pad_size, self.num_heads)),
pad_size,
self.chunk_size
)
# Output processing
y = y + x * self.D[None, None, :, None]
y = y.reshape((-1, y.shape[1], self.args.intermediate_size))
y = self.norm(y, z)
y = self.out_proj(y)
return y
def forward_inference(self, u: mx.array, cache=None) -> mx.array:
"""Single token processing during inference."""
assert u.shape[1] == 1, "Inference mode expects single token"
# Compute cumulative sums
A_cumsum = mx.cumsum(A_chunks, axis=-1)
L = mx.exp(self.segment_sum(A_chunks))
batch_size = u.shape[0]
# Use provided cache or create new one
self.cache = cache if cache is not None else Mamba2Cache.get_cache(self.args, batch_size, None)
# Process diagonal blocks
G = mx.einsum('...lhn,...shn->...lsh', C_chunks, B_chunks)
M = G * L[..., None, :]
Y_diag = mx.einsum('...lsh,...sh->...lh', M, x_chunks)
# Process off-diagonal blocks
decay_states = mx.exp(A_cumsum[..., -1:] - A_cumsum)
B_decay = B_chunks * decay_states[..., None]
states = mx.einsum('...shn,...sh->...hn', B_decay, x_chunks)
# Combine results
y = Y_diag + states
# Remove padding if necessary
if pad_size > 0:
y = y[:, :seq_len]
return y, ssm_state
def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
batch_size, seq_len, _ = x.shape
# Project input
zxbcdt = self.in_proj(mx.squeeze(u, 1))
parts = mx.split(
zxbcdt,
[
self.args.intermediate_size,
self.args.intermediate_size + 2 * self.args.state_size
],
axis=-1
)
z, xBC = parts[0], parts[1]
dt = zxbcdt[:, -self.args.num_heads:] # Extract dt separately
# Update convolution state and apply
conv_state = self.cache.update_conv_state(xBC)
xBC = mx.sum(
conv_state * mx.transpose(self.conv1d.weight, [1, 0, 2]),
axis=-1
)
if self.args.use_conv_bias:
xBC = xBC + self.conv1d.bias
xBC = mx.sigmoid(xBC) * xBC # SiLU
# Split states and ensure proper shapes
x_splits = mx.split(
xBC,
[self.args.intermediate_size, self.args.state_size],
axis=-1
)
x, B, C = x_splits[0], x_splits[1], x_splits[2]
projected_states = self.in_proj(x.squeeze(1))
# Process time steps - ensure proper broadcasting
dt = mx.reshape(dt, (batch_size, self.args.num_heads))
dt = mx.clip(
nn.softplus(dt + self.dt_bias[None, :]),
self.args.time_step_min,
self.args.time_step_max
)
# Calculate d_mlp based on projection size
d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 *
self.n_groups * self.ssm_state_size - self.num_heads) // 2
# SSM step with explicit shapes
A = -mx.exp(self.A_log)
dA = mx.exp(dt * A[None, :]) # Shape: (batch_size, num_heads)
# Split projections with corrected dimensions
splits = [
d_mlp, # z0
d_mlp, # x0
self.intermediate_size, # gate
self.conv_dim, # hidden_states
self.num_heads # dt
]
# Reshape x considering intermediate size
# x shape should be (batch_size * num_heads, head_dim)
x = mx.reshape(x, (batch_size, self.args.num_heads, -1))
assert x.shape[-1] == self.args.head_dim, f"Head dimension mismatch: {x.shape[-1]} vs {self.args.head_dim}"
z0, x0, x1, gate, hidden_states, dt = projected_states.split(splits, axis=-1)
# Reshape B and C for ssm computation
B = mx.reshape(B, (batch_size, -1)) # Should be (batch_size, state_size)
C = mx.reshape(C, (batch_size, -1)) # Should be (batch_size, state_size)
# Split hidden states into components
x_conv, BC = mx.split(hidden_states, [self.intermediate_size], axis=-1)
B, C = mx.split(BC, [self.n_groups * self.ssm_state_size], axis=-1)
# Compute dBx with explicit shapes
dBx = mx.einsum('bh,bs,bhd->bhds', dt, B, x)
# Process based on sequence length
if seq_len > 1 and cache is None:
y, next_state = self.process_long_sequence(
x_conv, B, C, dt,
mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
)
else:
# Reshape for single token processing
x_conv = x_conv.reshape(batch_size, -1, self.head_dim)
B = B.reshape(batch_size, self.num_heads, -1)
C = C.reshape(batch_size, self.num_heads, -1)
y, conv_state, next_state = self.process_single_token(x_conv, B, C, dt, cache)
if cache is not None:
cache.update(conv_state, next_state)
ssm_state = self.cache.update_ssm_state(dA, dBx)
y = mx.einsum('bhds,bs->bhd', ssm_state, C)
y = y + x * self.D[None, :, None]
y = mx.reshape(y, (batch_size, self.args.intermediate_size))
# Output processing
y = self.norm(y, z)
y = self.out_proj(y)
return mx.expand_dims(y, 1)
# Apply normalization and final projection
y = self.norm(y) * gate
return self.out_proj(y)
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.mixer = Mamba2Block(args)
self.mixer = Mamba2Mixer(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache=None) -> mx.array:
def __call__(self, x: mx.array, cache: Optional[Mamba2Cache] = None) -> mx.array:
return self.mixer(self.norm(x), cache) + x
class Mamba2Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
@@ -295,19 +308,20 @@ class Mamba2Model(nn.Module):
def __call__(self, x: mx.array, cache=None) -> mx.array:
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
for layer, layer_cache in zip(self.layers, cache):
x = layer(x, layer_cache)
return self.norm_f(x)
return self.norm_f(x)
class Model(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.backbone = Mamba2Model(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
@@ -324,17 +338,24 @@ class Model(nn.Module):
return logits
def make_cache(self, batch_size=1):
return [Mamba2Cache(
batch_size=batch_size,
intermediate_size=self.args.intermediate_size,
state_size=self.args.state_size,
conv_kernel=self.args.conv_kernel,
num_heads=self.args.num_heads,
head_dim=self.args.head_dim
) for _ in range(len(self.backbone.layers))]
return [
Mamba2Cache(
batch_size=batch_size,
conv_dim=self.args.intermediate_size + 2 * self.args.n_groups * self.args.state_size,
kernel_size=self.args.conv_kernel,
num_heads=self.args.num_heads,
head_dim=self.args.head_dim,
state_size=self.args.state_size
)
for _ in range(len(self.backbone.layers))
]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.ndim == 3:
weights[k] = v.moveaxis(2, 1)
return weights
return weights
@property
def layers(self):
return self.backbone.layers