generation works but outputs gibberish

Author: Goekdeniz-Guelmez
Date:   2024-10-20 18:04:34 +02:00
parent  4ab5139c05
commit  ab4cf1d1cf

2 changed files with 102 additions and 208 deletions
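
The commit message refers to running text generation against this model port. The command used to exercise it appears as a comment inside the model file further down in this diff; reproducing the reported behaviour would look like (the model repo name is taken directly from that comment):

    python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."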

View File

@@ -338,30 +338,3 @@ class MambaCache(_BaseCache):
     @state.setter
     def state(self, v):
         self.cache = v
-
-
-class Mamba2Cache:
-    def __init__(self, num_layers):
-        self.conv_states = [None] * num_layers
-        self.ssm_states = [None] * num_layers
-        self.seqlen_offset = 0
-
-    def __getitem__(self, idx):
-        return (self.conv_states[idx], self.ssm_states[idx])
-
-    def __setitem__(self, idx, value):
-        self.conv_states[idx], self.ssm_states[idx] = value
-
-    @property
-    def state(self):
-        return {
-            'conv_states': self.conv_states,
-            'ssm_states': self.ssm_states,
-            'seqlen_offset': self.seqlen_offset
-        }
-
-    @state.setter
-    def state(self, v):
-        self.conv_states = v['conv_states']
-        self.ssm_states = v['ssm_states']
-        self.seqlen_offset = v['seqlen_offset']
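
The hunk above drops the layer-indexed Mamba2Cache (per-layer conv_states/ssm_states lists plus a seqlen_offset) in favour of the existing per-layer MambaCache, whose state setter is visible in the retained context lines. A minimal sketch of how the new block code in mamba2.py addresses that cache, assuming MambaCache simply wraps a two-slot list with item access (which is what the cache[0]/cache[1] accesses in the new block imply), might be:

    cache = MambaCache()            # one per layer, created by Model.make_cache()
    conv_state = cache[0]           # rolling window for the depthwise conv, None at first
    ssm_state = cache[1]            # recurrent SSM state, None at first
    cache[0], cache[1] = new_conv_state, new_ssm_state   # updated in place after each token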

View File

@@ -1,16 +1,11 @@
 # Copyright © 2024 Apple Inc.

 import math
 from dataclasses import dataclass, field
-from typing import Tuple, Union, Optional
-
-import mlx.nn as nn
+from typing import Tuple, Union

 import mlx.core as mx
+import mlx.nn as nn

 from .base import BaseModelArgs
-from .cache import Mamba2Cache
+from .cache import MambaCache
+
+# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."

 @dataclass
 class ModelArgs(BaseModelArgs):
@@ -48,6 +43,7 @@ class ModelArgs(BaseModelArgs):
         if self.time_step_rank == "auto":
             self.time_step_rank = math.ceil(self.hidden_size / 16)

+
 class MambaRMSNormGated(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
@@ -97,175 +93,108 @@ class DepthWiseConv1d(nn.Module):
         return y, x[:, -K + 1 :, :]

-class Mamba2Mixer(nn.Module):
-    def __init__(self, args, layer_idx):
-        super().__init__()
-        self.layer_idx = layer_idx
-        self.hidden_size = args.hidden_size
-        self.intermediate_size = args.intermediate_size
-        self.num_heads = args.num_heads
-        self.head_dim = args.head_dim
-        self.ssm_state_size = args.state_size
-        self.n_groups = args.n_groups
-        self.conv_kernel_size = args.conv_kernel
-        self.use_conv_bias = args.use_conv_bias
-        self.use_bias = args.use_bias
-        self.time_step_min = args.time_step_min
-        self.time_step_max = args.time_step_max
-        self.chunk_size = args.chunk_size
-
-        self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
-        projection_size = self.intermediate_size + self.conv_dim + self.num_heads
-
-        self.in_proj = nn.Linear(
-            self.hidden_size,
-            projection_size,
-            bias=args.use_bias
-        )
-        self.conv1d = nn.Conv1d(
-            self.conv_dim,
-            self.conv_dim,
-            self.conv_kernel_size,
-            groups=self.conv_dim,
-            bias=self.use_conv_bias
-        )
-        self.act = nn.SiLU()
-        self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
-        self.out_proj = nn.Linear(
-            self.intermediate_size,
-            self.hidden_size,
-            bias=self.use_bias
-        )
-
-        self.A_log = mx.zeros(self.num_heads)
-        self.D = mx.ones(self.num_heads)
-        self.dt_bias = mx.zeros(self.num_heads)
-
-    def __call__(self, input_states, cache):
-        batch_size, seq_len, _ = input_states.shape
-        dtype = input_states.dtype
-        projected_states = self.in_proj(input_states)
-
-        # Calculate the sizes of each split
-        total_size = projected_states.shape[-1]
-        remaining_size = total_size - self.intermediate_size - self.conv_dim - self.num_heads
-        d_mlp = remaining_size // 2
-        sizes = [
-            d_mlp,
-            d_mlp,
-            self.intermediate_size,
-            self.conv_dim,
-            self.num_heads
-        ]
-
-        # Perform the split operation
-        split_result = mx.split(projected_states, sizes, axis=-1)
-
-        # Print debug information
-        print(f"Number of split parts: {len(split_result)}")
-        print(f"Shapes of split parts: {[part.shape for part in split_result]}")
-
-        # Flexibly handle the split result
-        _, _, _, gate, hidden_states, dt = split_result
-
-        if cache is not None:
-            conv_state = cache.conv_states[self.layer_idx]
-            if conv_state is None:
-                # Initialize conv_state if it's None
-                conv_state = mx.zeros((batch_size, 1, self.conv_kernel_size, hidden_states.shape[-1]))
-            conv_state = mx.roll(conv_state, -1, -2)  # Roll along the kernel dimension
-            # Reshape hidden_states to match conv_state dimensions
-            hidden_states_reshaped = hidden_states[:, None, None, :]
-            conv_state = mx.concat([conv_state[:, :, :-1, :], hidden_states_reshaped], axis=-2)
-            cache.conv_states[self.layer_idx] = conv_state
-            # Adjust the convolution operation
-            hidden_states = mx.sum(conv_state * self.conv1d.weight[:, :, None, :], axis=(-2, -1))
-            if self.use_conv_bias:
-                hidden_states += self.conv1d.bias
-            hidden_states = self.act(hidden_states)[:, None, :]
-        else:
-            hidden_states = hidden_states.transpose(0, 2, 1)
-            hidden_states = self.act(self.conv1d(hidden_states)).transpose(0, 2, 1)
-
-        hidden_states, B, C = mx.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], axis=-1)
-
-        A = -mx.exp(self.A_log.astype(mx.float32))
-        dt = nn.softplus(dt + self.dt_bias)
-        dt = mx.clip(dt, self.time_step_min, self.time_step_max)
-
-        hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).astype(mx.float32)
-        B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
-        C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
-        B = mx.repeat(B, repeats=self.num_heads // self.n_groups, axis=2)
-        C = mx.repeat(C, repeats=self.num_heads // self.n_groups, axis=2)
-
-        if cache is not None and cache.seqlen_offset > 0:
-            ssm_state = cache.ssm_states[self.layer_idx]
-            dA = mx.exp(dt[:, None, :, None] * A[None, :, None, None])
-            dB = dt[:, None, :, None] * B
-            dBx = dB * hidden_states[:, :, :, None]
-            ssm_state = ssm_state * dA + dBx
-            cache.ssm_states[self.layer_idx] = ssm_state
-            y = mx.sum(ssm_state * C[:, None, :, :], axis=-1)
-            D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
-            y = y + hidden_states * D
-            y = y.reshape(batch_size, -1)[:, None, :]
-        else:
-            # Implement chunked computation here (simplified version)
-            pad_size = self.chunk_size - (seq_len % self.chunk_size)
-            hidden_states_padded = mx.pad(hidden_states, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
-            B_padded = mx.pad(B, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
-            C_padded = mx.pad(C, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
-
-            chunks = seq_len // self.chunk_size + (1 if pad_size > 0 else 0)
-            y_list = []
-            ssm_state = mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
-
-            for i in range(chunks):
-                chunk_start = i * self.chunk_size
-                chunk_end = (i + 1) * self.chunk_size
-                chunk_h = hidden_states_padded[:, chunk_start:chunk_end]
-                chunk_B = B_padded[:, chunk_start:chunk_end]
-                chunk_C = C_padded[:, chunk_start:chunk_end]
-                chunk_dt = dt[:, chunk_start:chunk_end]
-
-                dA = mx.exp(chunk_dt[:, :, None, None] * A[None, None, :, None])
-                dB = chunk_dt[:, :, None, None] * chunk_B
-                dBx = dB * chunk_h[:, :, :, None]
-
-                chunk_y = mx.zeros_like(chunk_h)
-                for j in range(self.chunk_size):
-                    ssm_state = ssm_state * dA[:, j] + dBx[:, j]
-                    chunk_y[:, j] = mx.sum(ssm_state * chunk_C[:, j], axis=-1)
-
-                y_list.append(chunk_y)
-
-            y = mx.concat(y_list, axis=1)
-            if pad_size > 0:
-                y = y[:, :seq_len]
-
-            D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
-            y = y + hidden_states * D
-            y = y.reshape(batch_size, seq_len, -1)
-
-        y = self.norm(y, gate)
-        contextualized_states = self.out_proj(y.astype(dtype))
-        return contextualized_states
+class Mamba2Block(nn.Module):
+    def __init__(self, args: ModelArgs):
+        super().__init__()
+        self.args = args
+
+        self.intermediate_size = args.intermediate_size
+        self.time_step_rank = args.time_step_rank
+        self.conv_kernel_size = args.conv_kernel
+        self.hidden_size = args.hidden_size
+        self.state_size = args.state_size
+        self.num_heads = args.num_heads
+        self.head_dim = args.hidden_size // args.num_heads
+        self.n_groups = args.n_groups
+
+        self.conv_dim = args.intermediate_size + 2 * args.n_groups * args.state_size
+        self.conv1d = DepthWiseConv1d(
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            kernel_size=args.conv_kernel,
+            bias=args.use_conv_bias,
+            groups=self.conv_dim,
+            padding=args.conv_kernel - 1
+        )
+
+        projection_size = args.intermediate_size + self.conv_dim + args.num_heads
+        self.in_proj = nn.Linear(
+            args.hidden_size,
+            projection_size,
+            bias=args.use_bias
+        )
+
+        self.act = nn.SiLU()
+
+        self.A_log = mx.zeros(args.num_heads)
+        self.D = mx.ones((args.num_heads,))
+        self.dt_bias = mx.zeros(args.num_heads)
+
+        self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
+        self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
+
+    def ssm_step(self, x, state, dt_proj):
+        A = -mx.exp(self.A_log)
+        D = self.D
+        delta = nn.softplus(dt_proj + self.dt_bias)
+
+        B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
+
+        batch_size = B.shape[0]
+        B = B.reshape(batch_size, self.n_groups, self.state_size)
+        C = C.reshape(batch_size, -1, self.state_size)
+
+        delta = delta.reshape(batch_size, self.num_heads, 1)
+        A = A.reshape(1, self.num_heads, 1)
+
+        if state is None:
+            new_state = delta * B
+        else:
+            new_state = delta * (B + state * mx.exp(delta * A))
+
+        y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
+        y = y + D * x[:, :self.num_heads]
+        return y, new_state
+
+    def __call__(self, x, cache):
+        B, T, D = x.shape
+        if cache is None:
+            cache = [None, None]
+
+        outputs = []
+        for t in range(T):
+            xt = x[:, t, :]
+            xz = self.in_proj(xt)
+
+            x_t, z_t, dt_proj = mx.split(
+                xz,
+                indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
+                axis=-1
+            )
+
+            # Use the new DepthWiseConv1d with caching
+            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
+            x_t = conv_out.squeeze(1)
+            x_t = nn.silu(x_t)
+            y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
+            z_t = nn.silu(z_t)
+
+            # Element-wise multiplication
+            output_t = y_t[:, :, None] * z_t[:, None, :]
+
+            # Sum across the second dimension to match the intermediate_size
+            output_t = output_t.sum(axis=1)
+
+            output_t = self.out_proj(output_t)
+            outputs.append(output_t)
+
+        output = mx.stack(outputs, axis=1)
+        return output

-class Mamba2Block(nn.Module):
-    def __init__(self, args: ModelArgs, layer_idx: int):
+class ResidualBlock(nn.Module):
+    def __init__(self, args: ModelArgs):
         super().__init__()
-        self.mixer = Mamba2Mixer(args, layer_idx)
+        self.mixer = Mamba2Block(args)
         self.norm = nn.RMSNorm(args.hidden_size)

     def __call__(self, x: mx.array, cache):
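
The ssm_step added above performs a single-timestep update of the selective state-space recurrence. For orientation, a minimal standalone sketch of that recurrence for one head (illustrative only, not the committed code; names and shapes are chosen here for clarity) is:

    import mlx.core as mx

    def ssm_single_step(h, x_t, A, B, C, D, delta):
        # h:    (head_dim, state_size) previous SSM state, or zeros for the first token
        # x_t:  (head_dim,)            conv/silu output for this head at the current step
        # A:    scalar (negative), delta: scalar step size, D: scalar skip weight
        # B, C: (state_size,)          input/output projections produced by in_proj
        dA = mx.exp(delta * A)                         # discretized state decay
        dBx = (delta * B)[None, :] * x_t[:, None]      # discretized input term
        h = h * dA + dBx                               # state update
        y = (h * C[None, :]).sum(axis=-1) + D * x_t    # readout plus skip connection
        return y, h

Note that the committed ssm_step computes delta * (B + state * mx.exp(delta * A)), which multiplies the decayed state by delta as well; the sketch keeps the textbook form state * exp(delta * A) + delta * B for comparison.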
@@ -277,24 +206,16 @@ class Mamba2(nn.Module):
         super().__init__()
         self.args = args
         self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
-        self.layers = [Mamba2Block(args, idx) for idx in range(args.num_hidden_layers)]
+        self.layers = [ResidualBlock(args) for _ in range(args.num_hidden_layers)]
         self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)

-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None
-    ):
-        hidden_states = self.embeddings(inputs)
-        if cache is None:
-            cache = Mamba2Cache(len(self.layers))
-
-        for i, layer in enumerate(self.layers):
-            hidden_states = layer(hidden_states, cache[i])
-
-        hidden_states = self.norm_f(hidden_states)
-        return hidden_states
+    def __call__(self, x: mx.array, cache):
+        x = self.embeddings(x)
+        if cache is None:
+            cache = [None] * len(self.layers)
+        for layer, c in zip(self.layers, cache):
+            x = layer(x, c)
+        return self.norm_f(x)

 class Model(nn.Module):
@@ -302,7 +223,10 @@ class Model(nn.Module):
         super().__init__()
         self.args = args
         self.model_type = args.model_type
         self.backbone = Mamba2(args)
+        # self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+
         if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
@@ -316,9 +240,6 @@ class Model(nn.Module):
         else:
             logits = self.lm_head(x)

-        print(logits)
-        print(logits.shape)
-
         return logits

     def sanitize(self, weights):
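
Only the untied branch of the logit computation is visible in this hunk. For context, when args.tie_word_embeddings is true, mlx_lm models normally reuse the embedding matrix as the output head via nn.Embedding.as_linear; a sketch of the full branch under that assumption (the tied branch itself is not shown in this diff) would be:

        if self.args.tie_word_embeddings:
            # reuse the input embedding weights for the output projection
            logits = self.backbone.embeddings.as_linear(x)
        else:
            logits = self.lm_head(x)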
@@ -328,7 +249,7 @@ class Model(nn.Module):
         return weights

     def make_cache(self):
-        return [Mamba2Cache(self.args.num_hidden_layers) for _ in range(len(self.layers))]
+        return [MambaCache() for _ in range(len(self.layers))]

     @property
     def layers(self):
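
With this change, make_cache returns one MambaCache per residual block and Mamba2.__call__ zips the layer list with that cache list. A hypothetical usage sketch of how the cache would be threaded through incremental decoding, assuming Model.__call__ accepts a cache argument as the backbone call implies (the token loop and variable names here are illustrative, not part of the diff):

    cache = model.make_cache()                    # one MambaCache per layer
    for token_id in prompt_ids:
        # each step updates cache[i][0] (conv window) and cache[i][1] (SSM state) in place
        logits = model(mx.array([[token_id]]), cache=cache)
    next_id = mx.argmax(logits[:, -1, :], axis=-1)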