save checkpoint

Goekdeniz-Guelmez
2024-11-10 14:36:26 +01:00
parent 906f972d36
commit 800b60239c
4 changed files with 429 additions and 129 deletions


@@ -1,11 +1,11 @@
import math
from dataclasses import dataclass, field
-from typing import Tuple, Union
+from typing import Tuple, Union, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
-from .cache import Mamba2Cache
+from .cache import MambaCache
@dataclass
class ModelArgs(BaseModelArgs):
@@ -61,8 +61,9 @@ class MambaRMSNormGated(nn.Module):
def silu(x):
return x * mx.sigmoid(x)
def ssd(x, A, B, C, chunk_size):
-# Replace einsum operations with explicit reshape and matrix multiply
+# Not currently used
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
@@ -91,179 +92,134 @@ def ssd(x, A, B, C, chunk_size):
return mx.concatenate(outputs, axis=1), state
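Since the chunked ssd() path is flagged as unused in this checkpoint, the rewritten Mamba2Block.__call__ below carries the SSM state one token at a time instead. As a reference for what that per-step update computes, here is a minimal sketch that mirrors the state update further down; ssm_step is an illustrative name, not a function in this file, and the shapes are taken from the cache initialisation below:

def ssm_step(x_t, h, A, B_t, C_t, dt_t):
    # x_t: (batch, heads, head_dim); h: (batch, heads, head_dim, state)
    # A: (heads,); B_t, C_t: (batch, heads, state); dt_t: (batch, heads)
    dA = mx.exp(dt_t * A)                            # per-head decay in (0, 1) since A < 0
    dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)  # (batch, heads, 1, 1)
    dBx = mx.matmul(mx.expand_dims(x_t, -1),         # outer product of input and B
                    mx.expand_dims(B_t, -2))         # -> (batch, heads, head_dim, state)
    h = h * dA + dBx                                 # discretised state update
    y_t = mx.matmul(h, mx.expand_dims(C_t, -1))      # read-out against C
    return mx.squeeze(y_t, -1), h                    # (batch, heads, head_dim), new state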
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
# Initialize weight with correct shape [C_out, 1, kernel_size]
self.weight = mx.random.normal((out_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
# Handle caching for sequential processing
if cache is not None:
if cache.conv_states[0] is None:
cache.conv_states[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache.conv_states[0], x], axis=1)
# Process each channel independently
outputs = []
for c in range(C):
# Extract and reshape the channel
x_c = x[:, :, c] # [B, L]
x_c = mx.expand_dims(x_c, axis=1) # [B, 1, L]
# Get weight for this channel - already in correct shape [1, 1, K]
w_c = mx.expand_dims(self.weight[c], axis=0) # Ensure [1, 1, K]
# Apply convolution
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=self.padding
)
if self.bias is not None:
y_c = y_c + self.bias[c]
outputs.append(mx.squeeze(y_c, axis=1))
y = mx.stack(outputs, axis=-1)
# Update cache
if cache is not None:
cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
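The per-channel Python loop above is what this commit replaces: with groups equal to the number of channels, a single grouped convolution computes the same depthwise result in one call. A minimal sketch, assuming MLX's channels-last convention (input (B, L, C), weight (C, K, 1)); depthwise_conv1d is an illustrative helper, not part of this file:

def depthwise_conv1d(x, weight, bias=None, padding=0):
    # x: (B, L, C); weight: (C, K, 1), one length-K filter per channel
    y = mx.conv_general(x, weight, stride=1, padding=padding, groups=x.shape[-1])
    if bias is not None:
        y = y + bias          # bias: (C,), broadcasts over batch and length
    return y                  # (B, L_out, C)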
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.chunk_size = args.chunk_size
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
self.conv_dim = args.intermediate_size + 2 * args.state_size
-self.conv1d = DepthWiseConv1d(
+# Replace DepthWiseConv1d with grouped nn.Conv1d
+self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
kernel_size=args.conv_kernel,
-groups=self.conv_dim,
+groups=self.conv_dim, # Makes it depthwise
bias=args.use_conv_bias,
-padding=args.conv_kernel - 1
+padding=0 # We'll handle padding via cache
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
if args.rescale_prenorm_residual:
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
-def __call__(self, u: mx.array, cache=None):
-# Expect input shape: [batch_size, 1, hidden_size]
+def __call__(self, u: mx.array, cache: Optional[MambaCache] = None):
batch_size, seq_len, _ = u.shape
pad_size = self.chunk_size - (seq_len % self.chunk_size)
# Initialize cache if needed
if cache is None:
cache = MambaCache()
# Initialize states if needed
-if cache.conv_states[0] is None:
-cache.conv_states[0] = mx.zeros((
+if cache[0] is None: # conv state
+cache[0] = mx.zeros((
batch_size,
self.args.conv_kernel - 1,
self.conv_dim
))
-if cache.ssm_states[0] is None:
-cache.ssm_states[0] = mx.zeros((
+if cache[1] is None: # ssm state
+cache[1] = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
self.args.state_size
))
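# Assumption behind the cache[...] indexing used here: MambaCache behaves as a
# two-slot container (supports cache[0] / cache[1] get and set), where slot 0
# holds the last conv_kernel - 1 timesteps of xBC with shape
# (batch, conv_kernel - 1, conv_dim) and slot 1 holds the SSM state with shape
# (batch, num_heads, head_dim, state_size).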
# Project input
zxbcdt = self.in_proj(u)
# Split projections
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(self.args.num_heads):]
# Process delta time
dt = mx.reshape(dt, (batch_size, seq_len, self.args.num_heads))
-dt = mx.squeeze(dt, axis=0) # Remove sequence dimension for single token
+dt = mx.squeeze(dt, axis=0)
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
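# dt is the per-head step size of the discretisation: softplus keeps it positive,
# and the clip plus floor keep exp(dt * A) well conditioned. Because A = -mx.exp(A_log)
# is negative (see below), the resulting decay factor exp(dt * A) lies in (0, 1).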
-# Convolution step
-xBC = self.conv1d(xBC, cache=cache)
+# Handle convolution caching and padding
+conv_state = cache[0]
+if conv_state is not None:
+xBC = mx.concatenate([conv_state, xBC], axis=1)
+# Prepare input for conv1d: [B, C, L]
+xBC = mx.transpose(xBC, [0, 2, 1])
+# Apply convolution
+xBC = self.conv1d(xBC)
+# Update cache state
+cache[0] = mx.transpose(xBC, [0, 2, 1])[:, -self.args.conv_kernel+1:, :]
+# Return to [B, L, C] format
+xBC = mx.transpose(xBC, [0, 2, 1])
xBC = silu(xBC)
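# Length check for the cached, padding-free convolution above: the input is
# seq_len + (conv_kernel - 1) timesteps, so a "valid" convolution with kernel
# size conv_kernel returns exactly seq_len outputs and never looks ahead of
# the current token.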
# Split conv output
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
# Reshape for SSM
-x = mx.reshape(x, (batch_size, 1, self.args.num_heads, self.args.head_dim))
-x = mx.squeeze(x, axis=1)
-B = mx.reshape(B, (batch_size, 1, self.args.state_size))
+x = mx.reshape(x, (batch_size, seq_len, self.args.num_heads, self.args.head_dim))
+B = mx.reshape(B, (batch_size, seq_len, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, self.args.num_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
-C = mx.reshape(C, (batch_size, 1, self.args.state_size))
+C = mx.reshape(C, (batch_size, seq_len, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, self.args.num_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
# SSM state update
+ssm_state = cache[1]
A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
-x = mx.expand_dims(x, axis=3)
-dBx = mx.matmul(x, B)
-cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
+x = mx.expand_dims(x, axis=-1)
+dBx = mx.matmul(x, mx.expand_dims(B, axis=-2))
+new_ssm_state = ssm_state * mx.expand_dims(dA, -1) + dBx
+cache[1] = new_ssm_state
# Output computation
-y = mx.matmul(cache.ssm_states[0], C)
+y = mx.matmul(new_ssm_state, mx.expand_dims(C, axis=-1))
y = mx.squeeze(y, axis=-1)
# y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
if pad_size > 0:
y = y[:, :seq_len, :, :]
# Final reshape and projections
-y = mx.reshape(y, (batch_size, 1, self.args.num_heads * self.args.head_dim))
+y = mx.reshape(y, (batch_size, seq_len, -1))
y = self.norm(y + z)
return self.out_proj(y)
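With the cache-based path above, Mamba2Block is meant to be driven one token at a time during generation, reusing the same cache object across calls. A minimal sketch of that calling pattern for this work-in-progress checkpoint; `args` is assumed to be an already populated ModelArgs (its fields are elided in this diff) and the hidden states are random stand-ins:

block = Mamba2Block(args)                        # assumes a populated ModelArgs
cache = MambaCache()                             # conv/SSM state is created lazily on first call
h = mx.random.normal((1, 4, args.hidden_size))   # stand-in hidden states: (batch, seq, hidden)
for t in range(h.shape[1]):                      # step token by token, carrying state in `cache`
    y = block(h[:, t:t+1, :], cache=cache)       # y: (1, 1, hidden_size)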
@@ -322,21 +278,13 @@ class Model(nn.Module):
return logits
def make_cache(self, batch_size=1):
-return [Mamba2Cache(batch_size, self.args.conv_kernel) for _ in range(len(self.layers))]
+return [MambaCache() for _ in range(len(self.backbone.layers))]
def sanitize(self, weights):
-sanitized = {}
for k, v in weights.items():
if "conv1d.weight" in k:
# Ensure weights are in correct shape (channels, 1, kernel_size)
if v.ndim == 2:
v = mx.expand_dims(v, axis=1)
elif v.ndim == 1:
v = mx.expand_dims(mx.expand_dims(v, axis=0), axis=0)
sanitized[k] = v
else:
sanitized[k] = v
return sanitized
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights
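The moveaxis in the new sanitize handles the layout difference between exported PyTorch convolution weights and MLX's Conv1d: PyTorch stores a conv kernel as (out_channels, in_channels/groups, kernel_size) while MLX expects (out_channels, kernel_size, in_channels/groups). For a depthwise conv, in_channels/groups == 1, so a weight whose last dimension is already 1 has been converted and the v.shape[-1] != 1 guard skips it. A small shape illustration with made-up sizes:

w_torch = mx.zeros((256, 1, 4))   # (C_out, C_in/groups, K): last dim is 4, so sanitize converts it
w_mlx = w_torch.moveaxis(2, 1)    # (C_out, K, C_in/groups) == (256, 4, 1): last dim is 1, left alone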
@property
def layers(self):