quick save

Author: Goekdeniz-Guelmez
Date: 2024-10-20 16:11:39 +02:00
parent cd036ccfb5
commit 4ab5139c05
3 changed files with 266 additions and 170 deletions


@@ -2,11 +2,13 @@
import math
from dataclasses import dataclass, field
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
+from .cache import Mamba2Cache
+# python -m mlx_lm.generate --model rokyang/mamba2-130m-hf --prompt "hello how are you."
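
Note: Mamba2Cache is now imported from .cache and its definition is not part of this diff. A minimal sketch of the interface the call sites below appear to assume (constructor arity taken from make_cache, attribute names from the mixer); the actual module may differ:

    class Mamba2Cache:
        def __init__(self, num_layers):
            self.seqlen_offset = 0                   # number of tokens already processed
            self.conv_states = [None] * num_layers   # rolling conv input window per layer
            self.ssm_states = [None] * num_layers    # recurrent SSM state per layer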
@@ -46,22 +48,6 @@ class ModelArgs(BaseModelArgs):
if self.time_step_rank == "auto":
self.time_step_rank = math.ceil(self.hidden_size / 16)
-class Mamba2Cache:
-def __init__(self):
-self.cache = [None, None]
-def __setitem__(self, idx, value):
-self.cache[idx] = value
-def __getitem__(self, idx):
-return self.cache[idx]
-@property
-def state(self):
-return self.cache
class MambaRMSNormGated(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
super().__init__()
@@ -75,6 +61,7 @@ class MambaRMSNormGated(nn.Module):
hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
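
Note: the cached single-token path further down rolls a stored window of past inputs and reduces it against the depthwise filter instead of re-running the full convolution. A self-contained sketch of that idea with made-up sizes (the names here are illustrative, not this module's API):

    import mlx.core as mx

    batch, channels, kernel = 2, 8, 4
    window = mx.zeros((batch, kernel, channels))   # last `kernel` inputs, per channel
    weight = mx.ones((channels, kernel)) / kernel  # one filter per channel (depthwise)
    x_t = mx.ones((batch, channels))               # current timestep

    # shift the window by one step, append x_t, then reduce over the kernel axis
    window = mx.concatenate([window[:, 1:, :], x_t[:, None, :]], axis=1)
    y_t = mx.sum(window * weight.T[None], axis=1)  # (batch, channels)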
@@ -111,27 +98,22 @@ class DepthWiseConv1d(nn.Module):
class Mamba2Mixer(nn.Module):
-def __init__(self, args: ModelArgs):
+def __init__(self, args, layer_idx):
super().__init__()
self.args = args
+self.layer_idx = layer_idx
self.hidden_size = args.hidden_size
self.num_heads = args.num_heads
-self.intermediate_size = args.intermediate_size
-self.time_step_rank = args.time_step_rank
-self.conv_kernel_size = args.conv_kernel
-self.state_size = args.state_size
-self.head_dim = args.hidden_size // args.num_heads
+self.head_dim = args.head_dim
+self.ssm_state_size = args.state_size
self.n_groups = args.n_groups
-self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.state_size
-self.conv1d = DepthWiseConv1d(
-in_channels=self.conv_dim,
-out_channels=self.conv_dim,
-bias=args.use_conv_bias,
-kernel_size=args.conv_kernel,
-groups=self.conv_dim,
-padding=args.conv_kernel - 1
-)
+self.intermediate_size = args.intermediate_size
+self.conv_kernel_size = args.conv_kernel
+self.use_conv_bias = args.use_conv_bias
+self.use_bias = args.use_bias
+self.time_step_min = args.time_step_min
+self.time_step_max = args.time_step_max
+self.chunk_size = args.chunk_size
+self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
projection_size = self.intermediate_size + self.conv_dim + self.num_heads
self.in_proj = nn.Linear(
@@ -139,91 +121,151 @@ class Mamba2Mixer(nn.Module):
projection_size,
bias=args.use_bias
)
self.dt_bias = mx.ones((self.num_heads,))
self.A_log = mx.log(mx.arange(1, self.num_heads + 1))
self.D = mx.ones((self.num_heads,))
self.conv1d = nn.Conv1d(
self.conv_dim,
self.conv_dim,
self.conv_kernel_size,
groups=self.conv_dim,
bias=self.use_conv_bias
)
self.act = nn.SiLU()
self.norm = MambaRMSNormGated(self.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(
self.intermediate_size,
self.hidden_size,
bias=self.use_bias
)
self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=args.use_bias)
self.A_log = mx.zeros(self.num_heads)
self.D = mx.ones(self.num_heads)
self.dt_bias = mx.zeros(self.num_heads)
def __call__(self, input_states, cache):
batch_size, seq_len, _ = input_states.shape
dtype = input_states.dtype
def ssm_step(self, x, state, dt_proj):
print(f"ssm_step input shapes - x: {x.shape}, dt_proj: {dt_proj.shape}")
A = -mx.exp(self.A_log)
D = self.D
delta = nn.softplus(dt_proj + self.dt_bias)
projected_states = self.in_proj(input_states)
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}")
# Calculate the sizes of each split
total_size = projected_states.shape[-1]
remaining_size = total_size - self.intermediate_size - self.conv_dim - self.num_heads
d_mlp = remaining_size // 2
sizes = [
d_mlp,
d_mlp,
self.intermediate_size,
self.conv_dim,
self.num_heads
]
batch_size = B.shape[0]
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
print(f"After reshape - B: {B.shape}, C: {C.shape}")
# Perform the split operation
split_result = mx.split(projected_states, sizes, axis=-1)
delta = delta.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
# Print debug information
print(f"Number of split parts: {len(split_result)}")
print(f"Shapes of split parts: {[part.shape for part in split_result]}")
if state is None:
new_state = delta * B
# Flexibly handle the split result
_, _, _, gate, hidden_states, dt = split_result
if cache is not None:
conv_state = cache.conv_states[self.layer_idx]
if conv_state is None:
# Initialize conv_state if it's None
conv_state = mx.zeros((batch_size, 1, self.conv_kernel_size, hidden_states.shape[-1]))
conv_state = mx.roll(conv_state, -1, -2) # Roll along the kernel dimension
# Reshape hidden_states to match conv_state dimensions
hidden_states_reshaped = hidden_states[:, None, None, :]
conv_state = mx.concat([conv_state[:, :, :-1, :], hidden_states_reshaped], axis=-2)
cache.conv_states[self.layer_idx] = conv_state
# Adjust the convolution operation
hidden_states = mx.sum(conv_state * self.conv1d.weight[:, :, None, :], axis=(-2, -1))
if self.use_conv_bias:
hidden_states += self.conv1d.bias
hidden_states = self.act(hidden_states)[:, None, :]
else:
new_state = delta * (B + state * mx.exp(delta * A))
print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}")
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
print(f"ssm_step output shape - y: {y.shape}")
return y, new_state
hidden_states = hidden_states.transpose(0, 2, 1)
hidden_states = self.act(self.conv1d(hidden_states)).transpose(0, 2, 1)
def __call__(self, x, cache):
B, T, D = x.shape
print(f"__call__ input shape - x: {x.shape}")
if cache is None:
cache = [None, None]
hidden_states, B, C = mx.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], axis=-1)
outputs = []
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
print(f"After in_proj shape - xz: {xz.shape}")
x_t, z_t, dt_proj = mx.split(
xz,
indices_or_sections=[self.conv_dim, self.conv_dim + self.intermediate_size],
axis=-1
)
print(f"After split shapes - x_t: {x_t.shape}, z_t: {z_t.shape}, dt_proj: {dt_proj.shape}")
A = -mx.exp(self.A_log.astype(mx.float32))
dt = nn.softplus(dt + self.dt_bias)
dt = mx.clip(dt, self.time_step_min, self.time_step_max)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
print(f"Before ssm_step shape - x_t: {x_t.shape}")
y_t, cache[1] = self.ssm_step(x_t, cache[1], dt_proj)
z_t = nn.silu(z_t)
print(f"After ssm_step shapes - y_t: {y_t.shape}, z_t: {z_t.shape}")
# Element-wise multiplication
output_t = y_t[:, :, None] * z_t[:, None, :]
print(f"After multiplication shape - output_t: {output_t.shape}")
# Sum across the second dimension to match the intermediate_size
output_t = output_t.sum(axis=1)
print(f"After sum shape - output_t: {output_t.shape}")
output_t = self.out_proj(output_t)
print(f"After out_proj shape - output_t: {output_t.shape}")
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
print(f"Final output shape: {output.shape}")
return output
hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).astype(mx.float32)
B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).astype(mx.float32)
B = mx.repeat(B, repeats=self.num_heads // self.n_groups, axis=2)
C = mx.repeat(C, repeats=self.num_heads // self.n_groups, axis=2)
if cache is not None and cache.seqlen_offset > 0:
ssm_state = cache.ssm_states[self.layer_idx]
dA = mx.exp(dt[:, None, :, None] * A[None, :, None, None])
dB = dt[:, None, :, None] * B
dBx = dB * hidden_states[:, :, :, None]
ssm_state = ssm_state * dA + dBx
cache.ssm_states[self.layer_idx] = ssm_state
y = mx.sum(ssm_state * C[:, None, :, :], axis=-1)
D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
y = y + hidden_states * D
y = y.reshape(batch_size, -1)[:, None, :]
else:
# Implement chunked computation here (simplified version)
pad_size = self.chunk_size - (seq_len % self.chunk_size)
hidden_states_padded = mx.pad(hidden_states, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
B_padded = mx.pad(B, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
C_padded = mx.pad(C, [(0, 0), (0, pad_size), (0, 0), (0, 0)])
chunks = seq_len // self.chunk_size + (1 if pad_size > 0 else 0)
y_list = []
ssm_state = mx.zeros((batch_size, self.num_heads, self.head_dim, self.ssm_state_size))
for i in range(chunks):
chunk_start = i * self.chunk_size
chunk_end = (i + 1) * self.chunk_size
chunk_h = hidden_states_padded[:, chunk_start:chunk_end]
chunk_B = B_padded[:, chunk_start:chunk_end]
chunk_C = C_padded[:, chunk_start:chunk_end]
chunk_dt = dt[:, chunk_start:chunk_end]
dA = mx.exp(chunk_dt[:, :, None, None] * A[None, None, :, None])
dB = chunk_dt[:, :, None, None] * chunk_B
dBx = dB * chunk_h[:, :, :, None]
chunk_y = mx.zeros_like(chunk_h)
for j in range(self.chunk_size):
ssm_state = ssm_state * dA[:, j] + dBx[:, j]
chunk_y[:, j] = mx.sum(ssm_state * chunk_C[:, j], axis=-1)
y_list.append(chunk_y)
y = mx.concat(y_list, axis=1)
if pad_size > 0:
y = y[:, :seq_len]
D = self.D[None, :, None].expand(self.D.shape[0], self.head_dim)
y = y + hidden_states * D
y = y.reshape(batch_size, seq_len, -1)
y = self.norm(y, gate)
contextualized_states = self.out_proj(y.astype(dtype))
return contextualized_states
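
Note: both the single-step branch and the chunked loop above evaluate the same diagonal state-space recurrence, h_t = exp(dt_t * A) * h_{t-1} + dt_t * B_t * x_t with y_t = sum(C_t * h_t) + D * x_t. A stripped-down per-timestep reference with hypothetical shapes, useful for sanity-checking the vectorized code (not part of the commit):

    import mlx.core as mx

    heads, head_dim, state = 2, 4, 8
    A = -mx.exp(mx.zeros(heads))            # (heads,) negative decay rates
    D = mx.ones(heads)
    h = mx.zeros((heads, head_dim, state))  # recurrent state carried across steps

    def ssm_step(x_t, B_t, C_t, dt_t, h):
        # x_t: (heads, head_dim), B_t, C_t: (heads, state), dt_t: (heads,)
        dA = mx.exp(dt_t[:, None, None] * A[:, None, None])            # per-head decay
        dBx = dt_t[:, None, None] * B_t[:, None, :] * x_t[..., None]   # input injection
        h = h * dA + dBx
        y = mx.sum(h * C_t[:, None, :], axis=-1) + D[:, None] * x_t    # readout + skip
        return y, h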
class Mamba2Block(nn.Module):
-def __init__(self, args: ModelArgs):
+def __init__(self, args: ModelArgs, layer_idx: int):
super().__init__()
-self.mixer = Mamba2Mixer(args)
+self.mixer = Mamba2Mixer(args, layer_idx)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
@@ -235,7 +277,7 @@ class Mamba2(nn.Module):
super().__init__()
self.args = args
self.embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
-self.layers = [Mamba2Block(args) for idx in range(args.num_hidden_layers)]
+self.layers = [Mamba2Block(args, idx) for idx in range(args.num_hidden_layers)]
self.norm_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(
@@ -274,6 +316,9 @@ class Model(nn.Module):
else:
logits = self.lm_head(x)
+print(logits)
+print(logits.shape)
return logits
def sanitize(self, weights):
@@ -282,8 +327,8 @@ class Model(nn.Module):
weights[k] = v.moveaxis(2, 1)
return weights
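
Note: the moveaxis(2, 1) in sanitize above converts convolution weights from the PyTorch/Hugging Face layout (out_channels, in_channels_per_group, kernel_size) to the (out_channels, kernel_size, in_channels_per_group) layout that MLX's Conv1d expects. A quick illustration with a dummy tensor:

    import mlx.core as mx

    w_hf = mx.zeros((768, 1, 4))   # depthwise conv weight as stored in the HF checkpoint
    w_mlx = w_hf.moveaxis(2, 1)    # -> (768, 4, 1), the layout nn.Conv1d uses
    print(w_mlx.shape)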
-def make_cache(self, batch_size: int = 1):
-return [Mamba2Cache() for _ in range(len(self.layers))]
+def make_cache(self):
+return [Mamba2Cache(self.args.num_hidden_layers) for _ in range(len(self.layers))]
@property
def layers(self):
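
Note: a hedged sketch of how the per-layer caches returned by make_cache() would be threaded through token-by-token decoding. The exact Model.__call__ signature is not shown in this diff, and the real plumbing lives in mlx_lm.generate (see the command in the comment near the imports), so treat the call shapes below as assumptions:

    import mlx.core as mx

    def greedy_decode(model, prompt_ids, max_new_tokens=16):
        cache = model.make_cache()                        # one Mamba2Cache per layer
        tokens = list(prompt_ids)
        logits = model(mx.array([tokens]), cache=cache)   # assumed signature: prime on the prompt
        for _ in range(max_new_tokens):
            next_id = mx.argmax(logits[:, -1, :]).item()  # greedy pick of the next token
            tokens.append(next_id)
            logits = model(mx.array([[next_id]]), cache=cache)  # one step, reusing the caches
        return tokens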