generation works but outputs gibberish

Goekdeniz-Guelmez 2024-10-20 18:04:34 +02:00
parent 4ab5139c05
commit ab4cf1d1cf
2 changed files with 102 additions and 208 deletions

View File

@@ -338,30 +338,3 @@ class MambaCache(_BaseCache):
    @state.setter
    def state(self, v):
        self.cache = v


class Mamba2Cache:
    def __init__(self, num_layers):
        self.conv_states = [None] * num_layers
        self.ssm_states = [None] * num_layers
        self.seqlen_offset = 0

    def __getitem__(self, idx):
        return (self.conv_states[idx], self.ssm_states[idx])

    def __setitem__(self, idx, value):
        self.conv_states[idx], self.ssm_states[idx] = value

    @property
    def state(self):
        return {
            'conv_states': self.conv_states,
            'ssm_states': self.ssm_states,
            'seqlen_offset': self.seqlen_offset
        }

    @state.setter
    def state(self, v):
        self.conv_states = v['conv_states']
        self.ssm_states = v['ssm_states']
        self.seqlen_offset = v['seqlen_offset']
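
For reference, a minimal usage sketch of the Mamba2Cache class removed above: one cache object is shared across all layers and indexed by layer to read and write (conv_state, ssm_state) pairs. The shapes below are illustrative only, not the model's real dimensions.

    import mlx.core as mx

    cache = Mamba2Cache(num_layers=2)
    assert cache[0] == (None, None)                          # nothing stored before the first token
    cache[0] = (mx.zeros((1, 3, 8)), mx.zeros((1, 4, 16)))   # layer 0 writes back (conv_state, ssm_state)
    conv_state, ssm_state = cache[0]                         # read back through __getitem__
    snapshot = cache.state                                   # dict view exposed by the state property
    cache.state = snapshot                                   # and restored through the setter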

View File

@@ -1,16 +1,11 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass, field
from typing import Tuple, Union, Optional
import mlx.nn as nn
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import Mamba2Cache
# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."
from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
@@ -26,7 +21,7 @@ class ModelArgs(BaseModelArgs):
    n_groups: int
    use_bias: bool
    use_conv_bias: bool
    initializer_range: float
    residual_in_fp32: bool
    time_step_min: float
    time_step_max: float
@@ -48,6 +43,7 @@ class ModelArgs(BaseModelArgs):
        if self.time_step_rank == "auto":
            self.time_step_rank = math.ceil(self.hidden_size / 16)


class MambaRMSNormGated(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
@@ -60,7 +56,7 @@ class MambaRMSNormGated(nn.Module):
        variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
        hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states


class DepthWiseConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
@@ -95,177 +91,110 @@ class DepthWiseConv1d(nn.Module):
            y = y + self.bias
        return y, x[:, -K + 1 :, :]


class Mamba2Mixer(nn.Module):
    def __init__(self, args, layer_idx):

class Mamba2Block(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.layer_idx = layer_idx
        self.hidden_size = args.hidden_size
        self.args = args
        self.intermediate_size = args.intermediate_size
        self.num_heads = args.num_heads
        self.head_dim = args.head_dim
        self.ssm_state_size = args.state_size
        self.n_groups = args.n_groups
        self.time_step_rank = args.time_step_rank
        self.conv_kernel_size = args.conv_kernel
        self.use_conv_bias = args.use_conv_bias
        self.use_bias = args.use_bias
        self.time_step_min = args.time_step_min
        self.time_step_max = args.time_step_max
        self.chunk_size = args.chunk_size
        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
        self.hidden_size = args.hidden_size
        self.state_size = args.state_size
        self.num_heads = args.num_heads
        self.head_dim = args.hidden_size // args.num_heads
        self.n_groups = args.n_groups
        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
        self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
        self.conv1d = DepthWiseConv1d(
            in_channels=self.conv_dim,
            out_channels=self.conv_dim,
            kernel_size=args.conv_kernel,
            bias=args.use_conv_bias,
            groups=self.conv_dim,
            padding=args.conv_kernel - 1
        )
        projection_size = args.intermediate_size + self.conv_dim + args.num_heads
        self.in_proj = nn.Linear(
            self.hidden_size,
            args.hidden_size,
            projection_size,
            bias=args.use_bias
        )
        self.conv1d = nn.Conv1d(
            self.conv_dim,
            self.conv_dim,
            self.conv_kernel_size,
            groups=self.conv_dim,
            bias=self.use_conv_bias
        )
        self.act = nn.SiLU()
        self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
        self.out_proj = nn.Linear(
            self.intermediate_size,
            self.hidden_size,
            bias=self.use_bias
        )
        self.A_log = mx.zeros(self.num_heads)
        self.D = mx.ones(self.num_heads)
        self.dt_bias = mx.zeros(self.num_heads)
    def __call__(self, input_states, cache):
        batch_size, seq_len, _ = input_states.shape
        dtype = input_states.dtype
        self.A_log = mx.zeros(args.num_heads)
        self.D = mx.ones((args.num_heads,))
        self.dt_bias = mx.zeros(args.num_heads)
        projected_states = self.in_proj(input_states)
        # Calculate the sizes of each split
        total_size = projected_states.shape[-1]
        remaining_size = total_size - self.intermediate_size - self.conv_dim - self.num_heads
        d_mlp = remaining_size // 2
        sizes = [
            d_mlp,
            d_mlp,
            self.intermediate_size,
            self.conv_dim,
            self.num_heads
        ]
        # Perform the split operation
        split_result = mx.split(projected_states, sizes, axis=-1)
        # Print debug information
        print(f"Number of split parts: {len(split_result)}")
        print(f"Shapes of split parts: {[part.shape for part in split_result]}")
        # Flexibly handle the split result
        _, _, _, gate, hidden_states, dt = split_result
        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
        self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
        if cache is not None:
            conv_state = cache.conv_states[self.layer_idx]
            if conv_state is None:
                # Initialize conv_state if it's None
                conv_state = mx.zeros((batch_size, 1, self.conv_kernel_size, hidden_states.shape[-1]))
            conv_state = mx.roll(conv_state, -1, -2)  # Roll along the kernel dimension
            # Reshape hidden_states to match conv_state dimensions
            hidden_states_reshaped = hidden_states[:, None, None, :]
            conv_state = mx.concat([conv_state[:, :, :-1, :], hidden_states_reshaped], axis=-2)
            cache.conv_states[self.layer_idx] = conv_state
            # Adjust the convolution operation
            hidden_states = mx.sum(conv_state * self.conv1d.weight[:, :, None, :], axis=(-2, -1))
            if self.use_conv_bias:
                hidden_states += self.conv1d.bias
            hidden_states = self.act(hidden_states)[:, None, :]

    def ssm_step(self, x, state, dt_proj):
        A = -mx.exp(self.A_log)
        D = self.D
        delta = nn.softplus(dt_proj + self.dt_bias)
        B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
        batch_size = B.shape[0]
        B = B.reshape(batch_size, self.n_groups, self.state_size)
        C = C.reshape(batch_size, -1, self.state_size)
        delta = delta.reshape(batch_size, self.num_heads, 1)
        A = A.reshape(1, self.num_heads, 1)
        if state is None:
            new_state = delta * B
        else:
            hidden_states = hidden_states.transpose(0, 2, 1)
            hidden_states = self.act(self.conv1d(hidden_states)).transpose(0, 2, 1)
            new_state = delta * (B + state * mx.exp(delta * A))
        y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
        y = y + D * x[:, :self.num_heads]
        return y, new_state
        hidden_states, B, C = mx.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], axis=-1)
    def __call__(self, x, cache):
        B, T, D = x.shape
        if cache is None:
            cache = [None, None]
        A = -mx.exp(self.A_log.astype(mx.float32))
        dt = nn.softplus(dt + self.dt_bias)
        dt = mx.clip(dt, self.time_step_min, self.time_step_max)
        outputs = []
        for t in range(T):
            xt = x[:, t, :]
            xz = self.in_proj(xt)
            x_t, z_t, dt_proj = mx.split(
                xz,
                indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
                axis=-1
            )
        hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).astype(mx.float32)
        B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
        C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
        B = mx.repeat(B, repeats=self.num_heads // self.n_groups, axis=2)
        C = mx.repeat(C, repeats=self.num_heads // self.n_groups, axis=2)
        if cache is not None and cache.seqlen_offset > 0:
            ssm_state = cache.ssm_states[self.layer_idx]
            dA = mx.exp(dt[:, None, :, None] * A[None, :, None, None])
            dB = dt[:, None, :, None] * B
            dBx = dB * hidden_states[:, :, :, None]
            ssm_state = ssm_state * dA + dBx
            cache.ssm_states[self.layer_idx] = ssm_state
            y = mx.sum(ssm_state * C[:, None, :, :], axis=-1)
            D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
            y = y + hidden_states * D
            y = y.reshape(batch_size, -1)[:, None, :]
        else:
            # Implement chunked computation here (simplified version)
            pad_size = self.chunk_size - (seq_len % self.chunk_size)
            hidden_states_padded = mx.pad(hidden_states, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
            B_padded = mx.pad(B, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
            C_padded = mx.pad(C, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
            chunks = seq_len // self.chunk_size + (1 if pad_size > 0 else 0)
            y_list = []
            ssm_state = mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
            for i in range(chunks):
                chunk_start = i * self.chunk_size
                chunk_end = (i + 1) * self.chunk_size
                chunk_h = hidden_states_padded[:, chunk_start:chunk_end]
                chunk_B = B_padded[:, chunk_start:chunk_end]
                chunk_C = C_padded[:, chunk_start:chunk_end]
                chunk_dt = dt[:, chunk_start:chunk_end]
                dA = mx.exp(chunk_dt[:, :, None, None] * A[None, None, :, None])
                dB = chunk_dt[:, :, None, None] * chunk_B
                dBx = dB * chunk_h[:, :, :, None]
                chunk_y = mx.zeros_like(chunk_h)
                for j in range(self.chunk_size):
                    ssm_state = ssm_state * dA[:, j] + dBx[:, j]
                    chunk_y[:, j] = mx.sum(ssm_state * chunk_C[:, j], axis=-1)
                y_list.append(chunk_y)
            y = mx.concat(y_list, axis=1)
            if pad_size > 0:
                y = y[:, :seq_len]
            D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
            y = y + hidden_states * D
        y = y.reshape(batch_size, seq_len, -1)
        y = self.norm(y, gate)
        contextualized_states = self.out_proj(y.astype(dtype))
        return contextualized_states
            # Use the new DepthWiseConv1d with caching
            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
            x_t = conv_out.squeeze(1)
            x_t = nn.silu(x_t)
            y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
            z_t = nn.silu(z_t)
            # Element-wise multiplication
            output_t = y_t[:, :, None] * z_t[:, None, :]
            # Sum across the second dimension to match the intermediate_size
            output_t = output_t.sum(axis=1)
            output_t = self.out_proj(output_t)
            outputs.append(output_t)
        output = mx.stack(outputs, axis=1)
        return output


class Mamba2Block(nn.Module):
    def __init__(self, args: ModelArgs, layer_idx: int):

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.mixer = Mamba2Mixer(args, layer_idx)
        self.mixer = Mamba2Block(args)
        self.norm = nn.RMSNorm(args.hidden_size)

    def __call__(self, x: mx.array, cache):
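
The ssm_step method above implements a per-token variant of the discretized SSM recurrence, roughly h_t = exp(delta * A) * h_(t-1) + delta * B * x_t with readout y_t = C * h_t + D * x_t. Below is a minimal standalone sketch of that recurrence for a single token, with illustrative sizes and the head dimension collapsed to 1; it is not the model's exact shapes or code.

    import mlx.core as mx

    batch, num_heads, state_size = 1, 4, 16               # illustrative sizes only
    x = mx.random.normal((batch, num_heads))               # one token's per-head input
    B = mx.random.normal((batch, num_heads, state_size))   # input projection of the SSM
    C = mx.random.normal((batch, num_heads, state_size))   # output projection of the SSM
    A = -mx.exp(mx.zeros((1, num_heads, 1)))               # A = -exp(A_log), as in the mixer
    delta = 0.1 * mx.ones((batch, num_heads, 1))           # discretization step (post-softplus)
    D = mx.ones((num_heads,))                              # skip-connection weight

    h = mx.zeros((batch, num_heads, state_size))           # recurrent state carried in the cache
    h = h * mx.exp(delta * A) + delta * B * x[..., None]   # decay the old state, add the driven input
    y = mx.sum(h * C, axis=-1) + D * x                     # read out with C, add the D * x skip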
@@ -277,24 +206,16 @@ class Mamba2(nn.Module):
        super().__init__()
        self.args = args
        self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [Mamba2Block(args, idx) for idx in range(args.num_hidden_layers)]
        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
        self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

    def __call__(
        self,
        inputs: mx.array,
        cache=None
    ):
        hidden_states = self.embeddings(inputs)

    def __call__(self, x: mx.array, cache):
        x = self.embeddings(x)
        if cache is None:
            cache = Mamba2Cache(len(self.layers))
        for i, layer in enumerate(self.layers):
            hidden_states = layer(hidden_states, cache[i])
        hidden_states = self.norm_f(hidden_states)
        return hidden_states
            cache = [None] * len(self.layers)
        for layer, c in zip(self.layers, cache):
            x = layer(x, c)
        return self.norm_f(x)


class Model(nn.Module):
@@ -302,7 +223,10 @@ class Model(nn.Module):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.backbone = Mamba2(args)
        # self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
        if not args.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
@@ -316,11 +240,8 @@ class Model(nn.Module):
        else:
            logits = self.lm_head(x)
        print(logits)
        print(logits.shape)
        return logits

    def sanitize(self, weights):
        for k, v in weights.items():
            if "conv1d.weight" in k and v.ndim == 3:
@@ -328,7 +249,7 @@ class Model(nn.Module):
        return weights

    def make_cache(self):
        return [Mamba2Cache(self.args.num_hidden_layers) for _ in range(len(self.layers))]
        return [MambaCache() for _ in range(len(self.layers))]

    @property
    def layers(self):
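
The command in the comment near the top of the file can also be driven from Python. A hedged sketch, assuming the mlx_lm load/generate helpers and the rokyang/mamba2-130m-hf checkpoint referenced above:

    from mlx_lm import load, generate

    # Load the checkpoint referenced in the file's comment and decode a short continuation;
    # generate() builds the per-layer caches through Model.make_cache() when it is defined.
    model, tokenizer = load("rokyang/mamba2-130m-hf")
    print(generate(model, tokenizer, prompt="hello how are you.", max_tokens=50))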