adding debug statements (somehow generation only goes through the first MambaMixer block pass)

Goekdeniz-Guelmez 2024-10-16 21:09:30 +02:00
parent 00ba27fe6c
commit 8073cb486c
2 changed files with 163 additions and 151 deletions

View File

@@ -2,7 +2,7 @@
 import math
 from dataclasses import dataclass, field
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -48,14 +48,18 @@ class ModelArgs(BaseModelArgs):
 class Mamba2Cache:
-    def __init__(self, num_layers):
-        self.cache = [[None, None] for _ in range(num_layers)]
+    def __init__(self):
+        self.cache = [None, None]
+
+    def __setitem__(self, idx, value):
+        self.cache[idx] = value

     def __getitem__(self, idx):
         return self.cache[idx]

-    def __setitem__(self, idx, value):
-        self.cache[idx] = value
+    @property
+    def state(self):
+        return self.cache


 class MambaRMSNormGated(nn.Module):
@@ -71,6 +75,40 @@ class MambaRMSNormGated(nn.Module):
         hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states

+
+class DepthWiseConv1d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.padding = padding
+        self.groups = groups if groups is not None else in_channels
+
+        # Ensure in_channels and out_channels are the same for depthwise conv
+        assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
+        # Ensure groups is equal to in_channels for depthwise conv
+        assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
+
+        # Initialize weight with shape (out_channels, kernel_size, 1)
+        self.weight = mx.random.normal((out_channels, kernel_size, 1))
+        self.bias = mx.zeros((out_channels,)) if bias else None
+
+    def __call__(self, x, cache=None):
+        B, L, C = x.shape
+        _, K, _ = self.weight.shape
+
+        if cache is not None:
+            x = mx.concatenate([cache, x], axis=1)
+        else:
+            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
+
+        y = mx.conv_general(x, self.weight, groups=self.groups)
+        if self.bias is not None:
+            y = y + self.bias
+
+        return y, x[:, -K + 1 :, :]
+

 class Mamba2Mixer(nn.Module):
     def __init__(self, args: ModelArgs):
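The DepthWiseConv1d added above is a causal depthwise convolution that also returns the last kernel_size - 1 inputs so the next call can continue from them. A minimal sketch of the intended behaviour, assuming the class above and an installed mlx (illustrative only, not part of the commit): running a sequence token by token while threading the returned cache should match one zero-padded pass over the full sequence.

    import mlx.core as mx

    conv = DepthWiseConv1d(in_channels=8, out_channels=8, kernel_size=4)
    x = mx.random.normal((1, 6, 8))                 # (batch, length, channels)

    full, _ = conv(x)                               # single zero-padded pass over the whole sequence

    cache = None
    steps = []
    for t in range(6):
        y_t, cache = conv(x[:, t:t + 1, :], cache)  # one token at a time, reusing the cache
        steps.append(y_t)
    stepped = mx.concatenate(steps, axis=1)
    # full and stepped should agree up to floating-point error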
@@ -82,11 +120,11 @@ class Mamba2Mixer(nn.Module):
         self.hidden_size = args.hidden_size
         self.state_size = args.state_size
         self.num_heads = args.num_heads
-        self.head_dim = args.head_dim
+        self.head_dim = args.hidden_size // args.num_heads
         self.n_groups = args.n_groups

         self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
-        self.conv1d = nn.Conv1d(
+        self.conv1d = DepthWiseConv1d(
             in_channels=self.conv_dim,
             out_channels=self.conv_dim,
             bias=args.use_conv_bias,
@@ -102,7 +140,6 @@ class Mamba2Mixer(nn.Module):
             bias=args.use_bias
         )

-        self.act = nn.SiLU()
         self.dt_bias = mx.ones((self.num_heads,))
         self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
         self.D = mx.ones((self.num_heads,))
@@ -111,105 +148,84 @@ class Mamba2Mixer(nn.Module):
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)

-    def ssm_step(self, x, dt, state):
-        B, L, C = x.shape
-        print(f"x shape: {x.shape}")
-        projected_states = self.in_proj(x)
-        print(f"deltaBC shape: {projected_states.shape}")
-        d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.state_size - self.num_heads) // 2
-        gate = projected_states[:, :, 2*d_mlp:2*d_mlp+self.intermediate_size]
-        conv_state = projected_states[:, :, 2*d_mlp+self.intermediate_size:2*d_mlp+self.intermediate_size+self.conv_dim]
-        time_step = projected_states[:, :, -self.num_heads:]
-        print(f"conv_state shape before reshape: {conv_state.shape}")
-        print(f"self.conv_dim: {self.conv_dim}")
-        # Reshape and handle the case where L=1
-        conv_state = conv_state.reshape(B, self.conv_dim, L)
-        if L == 1:
-            # If sequence length is 1, we need to pad to apply convolution
-            conv_state = mx.pad(conv_state, ((0, 0), (0, 0), (0, self.conv_kernel_size - 1)))
-        conv_out = self.conv1d(conv_state)
-        # If we padded, we need to remove the padding
-        if L == 1:
-            conv_out = conv_out[:, :, :L]
-        # Reshape back to (B, L, C)
-        conv_out = conv_out.transpose(0, 2, 1)
-        x_and_conv_out, B, C = mx.split(
-            conv_out,
-            [self.intermediate_size, self.n_groups * self.state_size],
-            axis=-1
-        )
-        dt = nn.softplus(time_step + self.dt_bias)
-        dt = mx.clip(dt, self.args.time_step_min, self.args.time_step_max)
-        B = B.reshape(-1, self.num_heads, self.head_dim, self.state_size)
-        C = C.reshape(-1, self.num_heads, self.head_dim, self.state_size)
-        dA = mx.exp(dt[:, :, None, None] * A[None, :, None, None])
-        dB = dt[:, :, None, None] * B
-        new_state = state * dA + x_and_conv_out[:, :, None, None] * dB
-        y = mx.sum(new_state * C, axis=-1)
-        y = y + C[None, :, None] * x_and_conv_out
-        y = self.norm(y.reshape(-1, self.intermediate_size), gate)
-        output = self.out_proj(y)
-        return output, new_state
-
-    def __call__(
-        self,
-        x: mx.array,
-        cache = None
-    ):
-        B, L, _ = x.shape
-        if cache[0] is not None:  # Using cached state
-            conv_state, ssm_state = cache
-            x = x[:, -1:]
-            output, new_ssm_state = self.ssm_step(x, None, ssm_state)
-            cache[1] = new_ssm_state  # Update SSM state in cache
-        else:
-            conv_state, ssm_state = None, None
-            outputs = []
-            for t in range(L):
-                x = x[:, t:t+1]
-                output, ssm_state = self.ssm_step(x, None, ssm_state)
-                outputs.append(output)
-            output = mx.concatenate(outputs, axis=1)
-            cache[1] = ssm_state  # Store final SSM state in cache
-        # Update conv state in cache
-        new_conv_state = x[:, -self.conv_kernel_size:]
-        cache[0] = new_conv_state
+    def ssm_step(self, x, state, dt_proj):
+        A = -mx.exp(self.A_log)
+        D = self.D
+        delta = nn.softplus(dt_proj + self.dt_bias)
+
+        B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
+
+        B = B.reshape(-1, self.n_groups, self.state_size)
+        C = C.reshape(-1, self.n_groups, self.state_size)
+
+        if state is None:
+            new_state = mx.expand_dims(delta, -1) * B
+        else:
+            new_state = mx.expand_dims(delta, -1) * (B + state * mx.exp(mx.expand_dims(delta, -1) * A))
+
+        y = mx.sum(new_state * C, axis=-1)
+        y = y + D * x[:, :self.num_heads]
+        return y, new_state
+
+    def __call__(self, x, cache):
+        B, T, D = x.shape
+        if cache is None:
+            cache = [None, None]
+
+        outputs = []
+        for t in range(T):
+            xt = x[:, t, :]
+            xz = self.in_proj(xt)
+
+            x_t, z_t, dt_proj = mx.split(
+                xz,
+                indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
+                axis=-1
+            )
+
+            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
+            x_t = conv_out.squeeze(1)
+            x_t = nn.silu(x_t)
+            y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
+            z_t = nn.silu(z_t)
+
+            # Print shapes for debugging
+            print(f"y_t shape: {y_t.shape}")
+            print(f"z_t shape: {z_t.shape}")
+
+            # Reshape y_t to (B, num_heads, head_dim)
+            y_t_reshaped = y_t.reshape(B, self.num_heads, -1)
+            # Reshape z_t to (B, num_heads, intermediate_size // num_heads)
+            z_t_reshaped = z_t.reshape(B, self.num_heads, -1)
+
+            print(f"y_t_reshaped shape: {y_t_reshaped.shape}")
+            print(f"z_t_reshaped shape: {z_t_reshaped.shape}")
+
+            # Element-wise multiplication (broadcasting across the last dimension)
+            output_t = y_t_reshaped * z_t_reshaped
+
+            # Reshape to match the expected input of out_proj
+            output_t = output_t.reshape(B, -1)
+
+            print(f"output_t shape before out_proj: {output_t.shape}")
+            print(f"out_proj weight shape: {self.out_proj.weight.shape}")
+
+            output_t = self.out_proj(output_t)
+            outputs.append(output_t)
+
+        output = mx.stack(outputs, axis=1)
         return output


 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        self.args = args
-        self.residual_in_fp32 = args.residual_in_fp32
-        self.norm = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
         self.mixer = Mamba2Mixer(args)
+        self.norm = nn.RMSNorm(args.hidden_size)

-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None,
-    ):
-        h = self.mixer(self.norm(inputs), cache=cache)
-        r = inputs + h
-        return r
+    def __call__(self, x: mx.array, cache):
+        return self.mixer(self.norm(x), cache) + x


 class Mamba2(nn.Module):
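For orientation, the rewritten ssm_step above is a per-token selective state-space update. A minimal scalar sketch of the recurrence it is aiming at, following the usual Mamba formulation (reference only; h, x_t, dt, A, B_t, C_t and D are illustrative names, not the exact tensors used in the code):

    import math

    def ssm_update(h, x_t, dt, A, B_t, C_t, D):
        dA = math.exp(dt * A)        # discretized state decay (A is negative)
        dB = dt * B_t                # discretized input coefficient
        h = dA * h + dB * x_t        # recurrent state update
        y = C_t * h + D * x_t        # readout plus skip connection
        return y, h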
@@ -246,11 +262,7 @@ class Model(nn.Module):
         if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

-    def __call__(
-        self,
-        inputs: mx.array,
-        cache=None
-    ):
+    def __call__(self, inputs: mx.array, cache=None):
         B, T = inputs.shape

         x = self.backbone(inputs, cache)
@@ -259,17 +271,18 @@ class Model(nn.Module):
             logits = self.backbone.embeddings.as_linear(x)
         else:
             logits = self.lm_head(x)

         return logits

     def sanitize(self, weights):
         for k, v in weights.items():
             if "conv1d.weight" in k and v.ndim == 3:
                 weights[k] = v.moveaxis(2, 1)
         return weights

     def make_cache(self, batch_size: int = 1):
-        return Mamba2Cache(len(self.backbone.layers))
+        return [Mamba2Cache() for _ in range(len(self.layers))]

     @property
     def layers(self):
         return self.backbone.layers
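make_cache now returns one Mamba2Cache per layer instead of a single object indexed by layer. A hedged sketch of how that list is meant to be threaded through generation (hypothetical driver code; it assumes the backbone forwards cache[i] to layer i, which is not shown in this diff, and prompt_chunks is just an illustrative iterable of token arrays):

    cache = model.make_cache()                 # [Mamba2Cache(), ...], one entry per layer
    for token_ids in prompt_chunks:
        logits = model(token_ids, cache=cache)
    # the commit message notes that generation only seems to pass through the first
    # MambaMixer block, so this per-layer wiring is the part worth checking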

View File

@@ -6,37 +6,37 @@ from typing import Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs

+# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."

 @dataclass
 class ModelArgs(BaseModelArgs):
-    model_type: str = "mamba2"
-    num_heads: int = 128
-    head_dim: int = 64
-    vocab_size: int = 32768
-    hidden_size: int = 4096
-    state_size: int = 128
-    num_hidden_layers: int = 64
-    layer_norm_epsilon: float = 1e-5
-    expand: int = 2
-    conv_kernel: int = 4
-    n_groups: int = 8
-    use_bias: bool = False
-    use_conv_bias: bool = True
-    initializer_range: float = 0.1
-    residual_in_fp32: bool = True
-    time_step_rank: Union[int, str] = "auto"
-    time_step_min: float = 0.001
-    time_step_max: float = 0.1
-    time_step_floor: float = 1e-4
+    num_heads: int
+    head_dim: int
+    vocab_size: int
+    hidden_size: int
+    state_size: int
+    num_hidden_layers: int
+    layer_norm_epsilon: float
+    expand: int
+    conv_kernel: int
+    n_groups: int
+    use_bias: bool
+    use_conv_bias: bool
+    initializer_range: float
+    residual_in_fp32: bool
+    time_step_min: float
+    time_step_max: float
+    time_step_floor: float
+    rescale_prenorm_residual: bool
+    use_cache: bool
+    rms_norm: bool
+    chunk_size: int
+    tie_word_embeddings: bool
     time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
-    rescale_prenorm_residual: bool = False
-    use_cache: bool = True
-    rms_norm: bool = True
-    chunk_size: int = 256
-    tie_word_embeddings: bool = False
+    time_step_rank: Union[int, str] = "auto"
+    model_type: str = "mamba2"

     def __post_init__(self):
         if not hasattr(self, "intermediate_size"):
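With the defaults stripped, every field above must now come from the checkpoint's config.json. A small sketch of the usual mlx-lm loading path (illustrative; it assumes BaseModelArgs.from_dict filters unknown keys as in the other mlx-lm models, and "config.json" is just the conventional file name):

    import json

    with open("config.json") as f:
        config = json.load(f)

    args = ModelArgs.from_dict(config)   # a missing required field now surfaces as a TypeError
    model = Model(args)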
@@ -149,26 +149,35 @@ class Mamba2Mixer(nn.Module):
         self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)

     def ssm_step(self, x, state, dt_proj):
+        print(f"ssm_step input shapes - x: {x.shape}, dt_proj: {dt_proj.shape}")
         A = -mx.exp(self.A_log)
         D = self.D
         delta = nn.softplus(dt_proj + self.dt_bias)

         B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
+        print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}")

         B = B.reshape(-1, self.n_groups, self.state_size)
         C = C.reshape(-1, self.n_groups, self.state_size)
+        print(f"After reshape - B: {B.shape}, C: {C.shape}")
+
+        delta = delta.reshape(-1, self.num_heads, 1)
+        A = A.reshape(1, self.num_heads, 1)

         if state is None:
-            new_state = mx.expand_dims(delta, -1) * B
+            new_state = delta * B
         else:
-            new_state = mx.expand_dims(delta, -1) * (B + state * mx.exp(mx.expand_dims(delta, -1) * A))
+            new_state = delta * (B + state * mx.exp(delta * A))

+        print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}")
         y = mx.sum(new_state * C, axis=-1)
         y = y + D * x[:, :self.num_heads]
+        print(f"ssm_step output shape - y: {y.shape}")
         return y, new_state

     def __call__(self, x, cache):
         B, T, D = x.shape
+        print(f"__call__ input shape - x: {x.shape}")
         if cache is None:
             cache = [None, None]
@@ -176,47 +185,37 @@ class Mamba2Mixer(nn.Module):
         for t in range(T):
             xt = x[:, t, :]
             xz = self.in_proj(xt)
+            print(f"After in_proj shape - xz: {xz.shape}")

             x_t, z_t, dt_proj = mx.split(
                 xz,
                 indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
                 axis=-1
             )
+            print(f"After split shapes - x_t: {x_t.shape}, z_t: {z_t.shape}, dt_proj: {dt_proj.shape}")

             conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
             x_t = conv_out.squeeze(1)
             x_t = nn.silu(x_t)
+            print(f"Before ssm_step shape - x_t: {x_t.shape}")
             y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
             z_t = nn.silu(z_t)
+            print(f"After ssm_step shapes - y_t: {y_t.shape}, z_t: {z_t.shape}")

-            # Print shapes for debugging
-            print(f"y_t shape: {y_t.shape}")
-            print(f"z_t shape: {z_t.shape}")
-            print(f"self.num_heads: {self.num_heads}")
-            print(f"self.intermediate_size: {self.intermediate_size}")
-            print(f"self.head_dim: {self.head_dim}")
-
-            # Flexible reshaping
-            y_t_reshaped = y_t.reshape(B, -1, 1)
-            z_t_reshaped = z_t.reshape(B, y_t_reshaped.shape[1], -1)
-
-            # Print reshaped shapes
-            print(f"y_t_reshaped shape: {y_t_reshaped.shape}")
-            print(f"z_t_reshaped shape: {z_t_reshaped.shape}")
-
             # Element-wise multiplication
-            output_t = y_t_reshaped * z_t_reshaped
+            output_t = y_t[:, :, None] * z_t[:, None, :]
+            print(f"After multiplication shape - output_t: {output_t.shape}")

-            # Reshape to match the expected input of out_proj
-            output_t = output_t.reshape(B, self.intermediate_size)
+            # Sum across the second dimension to match the intermediate_size
+            output_t = output_t.sum(axis=1)
+            print(f"After sum shape - output_t: {output_t.shape}")

-            print(f"output_t shape before out_proj: {output_t.shape}")
-            print(f"out_proj weight shape: {self.out_proj.weight.shape}")
             output_t = self.out_proj(output_t)
+            print(f"After out_proj shape - output_t: {output_t.shape}")
             outputs.append(output_t)

         output = mx.stack(outputs, axis=1)
+        print(f"Final output shape: {output.shape}")
         return output
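One property of the gating math introduced above that may be handy while debugging (an illustrative check, not part of the commit): taking the outer product of y_t and z_t and then summing over axis 1 collapses to scaling each row of z_t by the sum of the corresponding y_t, since sum_h(y_h * z_j) = (sum_h y_h) * z_j.

    import mlx.core as mx

    y_t = mx.random.normal((2, 4))     # illustrative shapes only
    z_t = mx.random.normal((2, 16))

    a = (y_t[:, :, None] * z_t[:, None, :]).sum(axis=1)
    b = y_t.sum(axis=1, keepdims=True) * z_t
    # a and b are equal up to floating-point error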