quick save

Goekdeniz-Guelmez 2024-10-20 16:11:39 +02:00
parent cd036ccfb5
commit 4ab5139c05
3 changed files with 266 additions and 170 deletions

View File

@ -338,3 +338,30 @@ class MambaCache(_BaseCache):
@state.setter
def state(self, v):
self.cache = v
class Mamba2Cache:
def __init__(self, num_layers):
self.conv_states = [None] * num_layers
self.ssm_states = [None] * num_layers
self.seqlen_offset = 0
def __getitem__(self, idx):
return (self.conv_states[idx], self.ssm_states[idx])
def __setitem__(self, idx, value):
self.conv_states[idx], self.ssm_states[idx] = value
@property
def state(self):
return {
'conv_states': self.conv_states,
'ssm_states': self.ssm_states,
'seqlen_offset': self.seqlen_offset
}
@state.setter
def state(self, v):
self.conv_states = v['conv_states']
self.ssm_states = v['ssm_states']
self.seqlen_offset = v['seqlen_offset']
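# Illustrative usage (the two-layer setup below is an assumption for the example):
# each layer reads and writes its own slot, and seqlen_offset advances once per decoded token.
# cache = Mamba2Cache(num_layers=2)
# conv_state, ssm_state = cache[0]            # both None before the first step
# cache[0] = (new_conv_state, new_ssm_state)  # tuple is unpacked into the two per-layer lists
# cache.seqlen_offset += 1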

View File

@ -6,41 +6,37 @@ from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str = "mamba2"
num_heads: int = 128
head_dim: int = 64
vocab_size: int = 32768
hidden_size: int = 4096
state_size: int = 128
num_hidden_layers: int = 64
layer_norm_epsilon: float = 1e-5
pad_token_id: int = 1
bos_token_id: int = 0
eos_token_id: int = 2
expand: int = 2
conv_kernel: int = 4
n_groups: int = 8
use_bias: bool = False
use_conv_bias: bool = True
hidden_act: str = "silu"
initializer_range: float = 0.1
residual_in_fp32: bool = True
time_step_rank: Union[int, str] = "auto"
time_step_min: float = 0.001
time_step_max: float = 0.1
time_step_floor: float = 1e-4
num_heads: int
head_dim: int
vocab_size: int
hidden_size: int
state_size: int
num_hidden_layers: int
layer_norm_epsilon: float
expand: int
conv_kernel: int
n_groups: int
use_bias: bool
use_conv_bias: bool
initializer_range: float
residual_in_fp32: bool
time_step_min: float
time_step_max: float
time_step_floor: float
rescale_prenorm_residual: bool
use_cache: bool
rms_norm: bool
chunk_size: int
tie_word_embeddings: bool
time_step_limit: Tuple[float, float] = field(default_factory=lambda: (0.0, float("inf")))
rescale_prenorm_residual: bool = False
use_cache: bool = True
rms_norm: bool = True
chunk_size: int = 256
tie_word_embeddings: bool = False
time_step_rank: Union[int, str] = "auto"
model_type: str = "mamba2"
def __post_init__(self):
if not hasattr(self, "intermediate_size"):
@ -79,15 +75,24 @@ class MambaRMSNormGated(nn.Module):
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, groups=1, padding=0):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.channels = channels
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups
self.weight = mx.random.normal((self.channels, kernel_size, 1))
self.bias = mx.zeros((channels,)) if bias else None
self.groups = groups if groups is not None else in_channels
# Ensure in_channels and out_channels are the same for depthwise conv
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
# Ensure groups is equal to in_channels for depthwise conv
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
# Initialize weight with shape (out_channels, kernel_size, 1)
self.weight = mx.random.normal((out_channels, kernel_size, 1))
self.bias = mx.zeros((out_channels,)) if bias else None
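# With groups == in_channels, each channel is convolved with its own (kernel_size, 1)
# filter, which is what makes this a depthwise convolution.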
def __call__(self, x, cache=None):
B, L, C = x.shape
@ -116,16 +121,17 @@ class Mamba2Mixer(nn.Module):
self.hidden_size = args.hidden_size
self.state_size = args.state_size
self.num_heads = args.num_heads
self.head_dim = args.head_dim
self.head_dim = args.hidden_size // args.num_heads
self.n_groups = args.n_groups
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
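# conv_dim packs the x stream (intermediate_size) together with the B and C projections
# (n_groups * state_size each), so all three pass through the same depthwise conv.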
self.conv1d = DepthWiseConv1d(
channels=self.conv_dim,
kernel_size=self.conv_kernel_size,
bias=self.args.use_conv_bias,
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=args.use_conv_bias,
kernel_size=args.conv_kernel,
groups=self.conv_dim,
padding=self.conv_kernel_size - 1,
padding=args.conv_kernel - 1
)
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
@ -135,33 +141,35 @@ class Mamba2Mixer(nn.Module):
bias=args.use_bias
)
self.act = nn.SiLU()
self.dt_bias = mx.ones((self.num_heads,))
self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
self.D = mx.ones((self.num_heads,))
self.A_log = mx.zeros(self.num_heads)
self.D = mx.ones(self.num_heads)
self.dt_bias = mx.zeros(self.num_heads)
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def ssm_step(self, x, state=None):
def ssm_step(self, x, state, dt_proj):
A = -mx.exp(self.A_log)
D = self.D
deltaBC = self.x_proj(x)
delta, B, C = mx.split(
deltaBC,
indices_or_sections=[
self.time_step_rank,
self.time_step_rank + self.ssm_state_size,
],
axis=-1,
)
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None:
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
y = y + D * x
delta = nn.softplus(dt_proj + self.dt_bias)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
batch_size = B.shape[0]
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
delta = delta.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
if state is None:
new_state = delta * B
else:
new_state = delta * (B + state * mx.exp(delta * A))
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
return y, new_state
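# Recurrence implemented by ssm_step above, per head h (single-token update):
#   new_state[h] = delta[h] * (B + state[h] * exp(delta[h] * A[h]))
#   y[h]         = sum over groups and state dims of new_state[h] * C  +  D[h] * x[h]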
def __call__(self, x, cache):
@ -173,15 +181,28 @@ class Mamba2Mixer(nn.Module):
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
x_t, z_t, dt_proj = mx.split(
xz,
indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
axis=-1
)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1])
y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
z_t = nn.silu(z_t)
output_t = y_t * z_t
# Broadcasted outer product of the per-head SSM outputs against the gate
output_t = y_t[:, :, None] * z_t[:, None, :]
# Sum over the head axis so the result matches intermediate_size
output_t = output_t.sum(axis=1)
output_t = self.out_proj(output_t)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
return output
@ -240,6 +261,9 @@ class Model(nn.Module):
else:
logits = self.lm_head(x)
print(logits)
print(logits.shape)
return logits
def sanitize(self, weights):

View File

@ -2,11 +2,13 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union
from typing import Tuple, Union, Optional
import mlx.core as mx
import mlx.nn as nn
import mlx.core as mx
from .base import BaseModelArgs
from .cache import Mamba2Cache
# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."
@ -46,22 +48,6 @@ class ModelArgs(BaseModelArgs):
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
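# e.g. hidden_size = 768 gives time_step_rank = ceil(768 / 16) = 48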
class Mamba2Cache:
def __init__(self):
self.cache = [None, None]
def __setitem__(self, idx, value):
self.cache[idx] = value
def __getitem__(self, idx):
return self.cache[idx]
@property
def state(self):
return self.cache
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
@ -75,6 +61,7 @@ class MambaRMSNormGated(nn.Module):
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
@ -111,27 +98,22 @@ class DepthWiseConv1d(nn.Module):
class Mamba2Mixer(nn.Module):
def __init__(self, args: ModelArgs):
def __init__(self, args, layer_idx):
super().__init__()
self.args = args
self.intermediate_size = args.intermediate_size
self.time_step_rank = args.time_step_rank
self.conv_kernel_size = args.conv_kernel
self.layer_idx = layer_idx
self.hidden_size = args.hidden_size
self.state_size = args.state_size
self.intermediate_size = args.intermediate_size
self.num_heads = args.num_heads
self.head_dim = args.hidden_size // args.num_heads
self.head_dim = args.head_dim
self.ssm_state_size = args.state_size
self.n_groups = args.n_groups
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
self.conv1d = DepthWiseConv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
bias=args.use_conv_bias,
kernel_size=args.conv_kernel,
groups=self.conv_dim,
padding=args.conv_kernel - 1
)
self.conv_kernel_size = args.conv_kernel
self.use_conv_bias = args.use_conv_bias
self.use_bias = args.use_bias
self.time_step_min = args.time_step_min
self.time_step_max = args.time_step_max
self.chunk_size = args.chunk_size
self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(
@ -139,91 +121,151 @@ class Mamba2Mixer(nn.Module):
projection_size,
bias=args.use_bias
)
self.dt_bias = mx.ones((self.num_heads,))
self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
self.D = mx.ones((self.num_heads,))
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
def ssm_step(self, x, state, dt_proj):
print(f"ssm_step input shapes - x: {x.shape}, dt_proj: {dt_proj.shape}")
A = -mx.exp(self.A_log)
D = self.D
delta = nn.softplus(dt_proj + self.dt_bias)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}")
batch_size = B.shape[0]
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
print(f"After reshape - B: {B.shape}, C: {C.shape}")
delta = delta.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
if state is None:
new_state = delta * B
else:
new_state = delta * (B + state * mx.exp(delta * A))
print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}")
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
print(f"ssm_step output shape - y: {y.shape}")
return y, new_state
def __call__(self, x, cache):
B, T, D = x.shape
print(f"__call__ input shape - x: {x.shape}")
if cache is None:
cache = [None, None]
outputs = []
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
print(f"After in_proj shape - xz: {xz.shape}")
x_t, z_t, dt_proj = mx.split(
xz,
indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
axis=-1
self.conv1d = nn.Conv1d(
self.conv_dim,
self.conv_dim,
self.conv_kernel_size,
groups=self.conv_dim,
bias=self.use_conv_bias
)
self.act = nn.SiLU()
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(
self.intermediate_size,
self.hidden_size,
bias=self.use_bias
)
print(f"After split shapes - x_t: {x_t.shape}, z_t: {z_t.shape}, dt_proj: {dt_proj.shape}")
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
print(f"Before ssm_step shape - x_t: {x_t.shape}")
y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
z_t = nn.silu(z_t)
print(f"After ssm_step shapes - y_t: {y_t.shape}, z_t: {z_t.shape}")
self.A_log = mx.zeros(self.num_heads)
self.D = mx.ones(self.num_heads)
self.dt_bias = mx.zeros(self.num_heads)
# Element-wise multiplication
output_t = y_t[:, :, None] * z_t[:, None, :]
print(f"After multiplication shape - output_t: {output_t.shape}")
def __call__(self, input_states, cache):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
# Sum across the second dimension to match the intermediate_size
output_t = output_t.sum(axis=1)
print(f"After sum shape - output_t: {output_t.shape}")
projected_states = self.in_proj(input_states)
output_t = self.out_proj(output_t)
print(f"After out_proj shape - output_t: {output_t.shape}")
outputs.append(output_t)
# Calculate the sizes of each split
total_size = projected_states.shape[-1]
remaining_size = total_size - self.intermediate_size - self.conv_dim - self.num_heads
d_mlp = remaining_size // 2
sizes = [
d_mlp,
d_mlp,
self.intermediate_size,
self.conv_dim,
self.num_heads
]
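# Since projection_size = intermediate_size + conv_dim + num_heads, remaining_size is 0
# here and the two d_mlp chunks are empty.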
output = mx.stack(outputs, axis=1)
print(f"Final output shape: {output.shape}")
return output
# Perform the split operation (mx.split takes cumulative split indices, not section sizes)
split_indices = [sizes[0], sizes[0] + sizes[1], sizes[0] + sizes[1] + sizes[2], sizes[0] + sizes[1] + sizes[2] + sizes[3]]
split_result = mx.split(projected_states, split_indices, axis=-1)
# Print debug information
print(f"Number of split parts: {len(split_result)}")
print(f"Shapes of split parts: {[part.shape for part in split_result]}")
# Unpack the five parts: two (possibly empty) d_mlp chunks, gate, conv input, dt
_, _, gate, hidden_states, dt = split_result
if cache is not None:
conv_state = cache.conv_states[self.layer_idx]
if conv_state is None:
# Initialize conv_state if it's None
conv_state = mx.zeros((batch_size, 1, self.conv_kernel_size, hidden_states.shape[-1]))
conv_state = mx.roll(conv_state, -1, -2) # Roll along the kernel dimension
# Reshape hidden_states to match conv_state dimensions
hidden_states_reshaped = hidden_states[:, None, None, :]
conv_state = mx.concatenate([conv_state[:, :, :-1, :], hidden_states_reshaped], axis=-2)
cache.conv_states[self.layer_idx] = conv_state
# Adjust the convolution operation
hidden_states = mx.sum(conv_state * self.conv1d.weight[:, :, None, :], axis=(-2, -1))
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states)[:, None, :]
else:
# mlx's nn.Conv1d expects channels-last (B, L, C) input, so no transpose is needed;
# note the conv was built without padding, so it shortens the sequence by kernel_size - 1
hidden_states = self.act(self.conv1d(hidden_states))
hidden_states, B, C = mx.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], axis=-1)
A = -mx.exp(self.A_log.astype(mx.float32))
dt = nn.softplus(dt + self.dt_bias)
dt = mx.clip(dt, self.time_step_min, self.time_step_max)
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).astype(mx.float32)
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
B = mx.repeat(B, repeats=self.num_heads // self.n_groups, axis=2)
C = mx.repeat(C, repeats=self.num_heads // self.n_groups, axis=2)
if cache is not None and cache.seqlen_offset > 0:
ssm_state = cache.ssm_states[self.layer_idx]
dA = mx.exp(dt[:, None, :, None] * A[None, :, None, None])
dB = dt[:, None, :, None] * B
dBx = dB * hidden_states[:, :, :, None]
ssm_state = ssm_state * dA + dBx
cache.ssm_states[self.layer_idx] = ssm_state
y = mx.sum(ssm_state * C[:, None, :, :], axis=-1)
# mlx arrays have no .expand(); broadcast D to (num_heads, head_dim) instead
D = mx.broadcast_to(self.D[:, None], (self.num_heads, self.head_dim))
y = y + hidden_states * D
y = y.reshape(batch_size, -1)[:, None, :]
else:
# Implement chunked computation here (simplified version)
pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size
hidden_states_padded = mx.pad(hidden_states, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
B_padded = mx.pad(B, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
C_padded = mx.pad(C, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
chunks = seq_len // self.chunk_size + (1 if pad_size > 0 else 0)
y_list = []
ssm_state = mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
for i in range(chunks):
chunk_start = i * self.chunk_size
chunk_end = (i + 1) * self.chunk_size
chunk_h = hidden_states_padded[:, chunk_start:chunk_end]
chunk_B = B_padded[:, chunk_start:chunk_end]
chunk_C = C_padded[:, chunk_start:chunk_end]
chunk_dt = dt[:, chunk_start:chunk_end]
dA = mx.exp(chunk_dt[:, :, None, None] * A[None, None, :, None])
dB = chunk_dt[:, :, None, None] * chunk_B
dBx = dB * chunk_h[:, :, :, None]
chunk_y = mx.zeros_like(chunk_h)
for j in range(self.chunk_size):
ssm_state = ssm_state * dA[:, j] + dBx[:, j]
chunk_y[:, j] = mx.sum(ssm_state * chunk_C[:, j], axis=-1)
y_list.append(chunk_y)
y = mx.concatenate(y_list, axis=1)
if pad_size > 0:
y = y[:, :seq_len]
# mlx arrays have no .expand(); broadcast D to (num_heads, head_dim) instead
D = mx.broadcast_to(self.D[:, None], (self.num_heads, self.head_dim))
y = y + hidden_states * D
y = y.reshape(batch_size, seq_len, -1)
y = self.norm(y, gate)
contextualized_states = self.out_proj(y.astype(dtype))
return contextualized_states
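# Summary of the mixer forward above: with a warm cache (seqlen_offset > 0) it applies a
# single recurrent update against the stored ssm_state; otherwise it pads the sequence to
# a multiple of chunk_size and scans chunk by chunk, carrying ssm_state across chunks,
# before the gated RMS norm and the output projection.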
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
self.mixer = Mamba2Mixer(args)
self.mixer = Mamba2Mixer(args, layer_idx)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
@ -235,7 +277,7 @@ class Mamba2(nn.Module):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)]
self.layers = [Mamba2Block(args, idx) for idx in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(
@ -274,6 +316,9 @@ class Model(nn.Module):
else:
logits = self.lm_head(x)
print(logits)
print(logits.shape)
return logits
def sanitize(self, weights):
@ -282,8 +327,8 @@ class Model(nn.Module):
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self, batch_size: int = 1):
return [Mamba2Cache() for _ in range(len(self.layers))]
def make_cache(self):
return [Mamba2Cache(self.args.num_hidden_layers) for _ in range(len(self.layers))]
@property
def layers(self):