removed the custom Mamba2Cache adn updated the existing MambaCache but still only one input Token and outputs gibberish

This commit is contained in:
Goekdeniz-Guelmez 2024-11-10 16:57:03 +01:00
parent 49d3f188f8
commit 2f95b361a8
2 changed files with 192 additions and 48 deletions

View File

@ -421,24 +421,24 @@ class RotatingKVCache(_BaseCache):
class MambaCache: class MambaCache:
def __init__(self): def __init__(self):
# cache[0] is conv state, cache[1] is ssm state # [conv_state, ssm_state]
self.cache = [None, None] self.cache = [None, None]
self.offset = 0 self.offset = 0 # Sliding window caching
def __setitem__(self, idx, value): def __setitem__(self, idx, value):
self.cache[idx] = value self.cache[idx] = value
def __getitem__(self, idx): def __getitem__(self, idx):
return self.cache[idx] return self.cache[idx]
@property @property
def state(self): def state(self):
return self.cache return self.cache
@state.setter @state.setter
def state(self, v): def state(self, v):
self.cache = v self.cache = v
@property @property
def conv_states(self): def conv_states(self):
return [self.cache[0]] return [self.cache[0]]
@ -446,13 +446,7 @@ class MambaCache:
@property @property
def ssm_states(self): def ssm_states(self):
return [self.cache[1]] return [self.cache[1]]
def reset(self):
class Mamba2Cache: self.cache = [None, None]
def __init__(self): self.offset = 0
self.conv_states = [None] # Initialize as None, will be set on first use
self.ssm_states = [None] # Initialize as None, will be set on first use
@property
def state(self):
return [self.conv_states[0], self.ssm_states[0]]

View File

@ -5,7 +5,7 @@ import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from .base import BaseModelArgs from .base import BaseModelArgs
from .cache import Mamba2Cache from .cache import MambaCache
@dataclass @dataclass
class ModelArgs(BaseModelArgs): class ModelArgs(BaseModelArgs):
@ -91,6 +91,168 @@ def ssd(x, A, B, C, chunk_size):
return mx.concatenate(outputs, axis=1), state return mx.concatenate(outputs, axis=1), state
# class DepthWiseConv1d(nn.Module):
# def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
# super().__init__()
# self.in_channels = in_channels
# self.out_channels = out_channels
# self.kernel_size = kernel_size
# self.padding = padding
# self.groups = groups if groups is not None else in_channels
# assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
# assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
# self.weight = mx.random.normal((in_channels, 1, kernel_size))
# self.bias = mx.zeros((out_channels,)) if bias else None
# def __call__(self, x: mx.array, cache=None) -> mx.array:
# B, L, C = x.shape
# K = self.kernel_size
# assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
# if cache is not None:
# if isinstance(cache.conv_states[0], type(None)):
# cache.conv_states[0] = mx.zeros((B, K-1, C))
# x = mx.concatenate([cache.conv_states[0], x], axis=1)
# outputs = []
# for c in range(C):
# # Input prep debug
# x_c = x[:, :, c]
# x_c = mx.expand_dims(x_c, axis=1)
# # Weight prep debug
# w_c = self.weight[c]
# if w_c.ndim == 2:
# w_c = mx.expand_dims(w_c, axis=0)
# elif w_c.ndim == 1:
# w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
# y_c = mx.conv_general(
# x_c,
# w_c,
# stride=1,
# padding=0
# )
# if self.bias is not None:
# y_c = y_c + self.bias[c]
# y_c = mx.squeeze(y_c, axis=1)
# outputs.append(y_c)
# # Output statistics
# y = mx.stack(outputs, axis=-1)
# # Cache update debug
# if cache is not None:
# cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
# return y
# class Mamba2Block(nn.Module):
# def __init__(self, args: ModelArgs):
# super().__init__()
# self.args = args
# d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
# self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
# conv_dim = args.intermediate_size + 2 * args.state_size
# self.conv1d = DepthWiseConv1d(
# in_channels=conv_dim,
# out_channels=conv_dim,
# kernel_size=args.conv_kernel,
# groups=conv_dim,
# bias=args.use_conv_bias,
# padding=args.conv_kernel - 1
# )
# self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
# self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
# self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
# self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
# self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
# if args.rescale_prenorm_residual:
# layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
# self.out_proj.weight = self.out_proj.weight * layer_scale
# def __call__(self, u: mx.array, cache=None):
# # Expect input to be shape [batch_size, 1, dim]
# batch_size, seq_len, dimension = u.shape
# assert seq_len == 1, "Input should be a single token"
# # Initialize cache if needed
# if cache.conv_states[0] is None:
# conv_dim = self.args.intermediate_size + 2 * self.args.state_size
# cache.conv_states[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
# if cache.ssm_states[0] is None:
# cache.ssm_states[0] = mx.zeros((
# batch_size,
# self.args.num_heads,
# self.args.head_dim,
# self.args.state_size
# ))
# # Project input
# zxbcdt = self.in_proj(u)
# # Split projections
# n_heads = self.args.num_heads
# z = zxbcdt[:, :, :self.args.intermediate_size]
# xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
# dt = zxbcdt[:, :, -(n_heads):]
# # Time steps
# dt = mx.reshape(dt, (batch_size, n_heads))
# dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
# dt = mx.maximum(dt, self.args.time_step_floor)
# # Convolution
# xBC = self.conv1d(xBC, cache=cache)
# xBC = silu(xBC)
# # Split states
# x = xBC[:, :, :self.args.intermediate_size]
# B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
# C = xBC[:, :, -self.args.state_size:]
# # Reshape for SSM
# x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
# x = mx.squeeze(x, axis=1)
# B = mx.reshape(B, (batch_size, 1, self.args.state_size))
# B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
# B = mx.expand_dims(B, axis=2)
# C = mx.reshape(C, (batch_size, 1, self.args.state_size))
# C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
# C = mx.expand_dims(C, axis=3)
# # SSM updates
# A = -mx.exp(self.A_log)
# dA = mx.exp(dt * mx.expand_dims(A, 0))
# dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
# # Update state
# x = mx.expand_dims(x, axis=3)
# dBx = mx.matmul(x, B)
# cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
# # Compute output
# y = mx.matmul(cache.ssm_states[0], C)
# y = mx.squeeze(y, axis=-1)
# y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
# y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
# y = self.norm(y + z)
# return self.out_proj(y)
class DepthWiseConv1d(nn.Module): class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0): def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__() super().__init__()
@ -113,18 +275,17 @@ class DepthWiseConv1d(nn.Module):
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
if cache is not None: if cache is not None:
if isinstance(cache.conv_states[0], type(None)): # Access conv_state directly from cache[0]
cache.conv_states[0] = mx.zeros((B, K-1, C)) if cache[0] is None:
cache[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache.conv_states[0], x], axis=1) x = mx.concatenate([cache[0], x], axis=1)
outputs = [] outputs = []
for c in range(C): for c in range(C):
# Input prep debug
x_c = x[:, :, c] x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1) x_c = mx.expand_dims(x_c, axis=1)
# Weight prep debug
w_c = self.weight[c] w_c = self.weight[c]
if w_c.ndim == 2: if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0) w_c = mx.expand_dims(w_c, axis=0)
@ -143,12 +304,11 @@ class DepthWiseConv1d(nn.Module):
y_c = mx.squeeze(y_c, axis=1) y_c = mx.squeeze(y_c, axis=1)
outputs.append(y_c) outputs.append(y_c)
# Output statistics
y = mx.stack(outputs, axis=-1) y = mx.stack(outputs, axis=-1)
# Cache update debug # Update cache directly using cache[0]
if cache is not None: if cache is not None:
cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y return y
@ -171,11 +331,9 @@ class Mamba2Block(nn.Module):
padding=args.conv_kernel - 1 padding=args.conv_kernel - 1
) )
# Initialize parameters self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.dt_bias = mx.ones(args.num_heads) self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
A = mx.arange(1, args.num_heads + 1) self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.log(A)
self.D = mx.ones(args.num_heads)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon) self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias) self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
@ -185,47 +343,40 @@ class Mamba2Block(nn.Module):
self.out_proj.weight = self.out_proj.weight * layer_scale self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache=None): def __call__(self, u: mx.array, cache=None):
# Expect input to be shape [batch_size, 1, dim]
batch_size, seq_len, dimension = u.shape batch_size, seq_len, dimension = u.shape
assert seq_len == 1, "Input should be a single token" assert seq_len == 1, "Input should be a single token"
# Initialize cache if needed # Initialize cache states directly using indices
if cache.conv_states[0] is None: if cache[0] is None: # conv state
conv_dim = self.args.intermediate_size + 2 * self.args.state_size conv_dim = self.args.intermediate_size + 2 * self.args.state_size
cache.conv_states[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim)) cache[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
if cache.ssm_states[0] is None: if cache[1] is None: # ssm state
cache.ssm_states[0] = mx.zeros(( cache[1] = mx.zeros((
batch_size, batch_size,
self.args.num_heads, self.args.num_heads,
self.args.head_dim, self.args.head_dim,
self.args.state_size self.args.state_size
)) ))
# Project input
zxbcdt = self.in_proj(u) zxbcdt = self.in_proj(u)
# Split projections
n_heads = self.args.num_heads n_heads = self.args.num_heads
z = zxbcdt[:, :, :self.args.intermediate_size] z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size] xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(n_heads):] dt = zxbcdt[:, :, -(n_heads):]
# Time steps
dt = mx.reshape(dt, (batch_size, n_heads)) dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max) dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
dt = mx.maximum(dt, self.args.time_step_floor) dt = mx.maximum(dt, self.args.time_step_floor)
# Convolution
xBC = self.conv1d(xBC, cache=cache) xBC = self.conv1d(xBC, cache=cache)
xBC = silu(xBC) xBC = silu(xBC)
# Split states
x = xBC[:, :, :self.args.intermediate_size] x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size] B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:] C = xBC[:, :, -self.args.state_size:]
# Reshape for SSM
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim)) x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1) x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size)) B = mx.reshape(B, (batch_size, 1, self.args.state_size))
@ -235,24 +386,23 @@ class Mamba2Block(nn.Module):
C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size)) C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3) C = mx.expand_dims(C, axis=3)
# SSM updates
A = -mx.exp(self.A_log) A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0)) dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
# Update state
x = mx.expand_dims(x, axis=3) x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B) dBx = mx.matmul(x, B)
cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx # Update ssm state directly using cache[1]
cache[1] = cache[1] * dA + dBx
# Compute output y = mx.matmul(cache[1], C)
y = mx.matmul(cache.ssm_states[0], C)
y = mx.squeeze(y, axis=-1) y = mx.squeeze(y, axis=-1)
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim)) y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y + z) y = self.norm(y + z)
return self.out_proj(y) return self.out_proj(y)
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
@ -308,7 +458,7 @@ class Model(nn.Module):
return logits return logits
def make_cache(self): def make_cache(self):
return [Mamba2Cache() for _ in range(len(self.layers))] return [MambaCache() for _ in range(len(self.layers))]
@property @property
def layers(self): def layers(self):