Fixed inference slowness, but it can't handle multi-token inputs and is generating gibberish

Goekdeniz-Guelmez 2024-11-10 16:35:07 +01:00
parent 800b60239c
commit 3a499f9735
2 changed files with 158 additions and 151 deletions
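For orientation: the changes in the second file below move the Mamba2 block to a single-token recurrent path that updates a per-layer conv cache and SSM state in place, instead of scanning the whole sequence. A minimal sketch of that per-token recurrence follows; `ssm_step` and its argument shapes are illustrative assumptions, not code from this diff.

```python
import mlx.core as mx

def ssm_step(h, x, dt, A, B, C, D):
    # Single-token Mamba2-style state update (illustrative shapes):
    #   h : (batch, heads, head_dim, state)  previous SSM state
    #   x : (batch, heads, head_dim)         this token's conv/SiLU output
    #   dt: (batch, heads)                   discretization step (after softplus/clipping)
    #   A, D: (heads,)    B, C: (batch, heads, state)
    dA = mx.exp(dt * A)                                      # per-head decay
    dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)          # (batch, heads, 1, 1)
    dt_x = mx.expand_dims(dt, -1) * x                        # scale the input by dt
    dBx = mx.expand_dims(dt_x, -1) * mx.expand_dims(B, -2)   # outer product -> (batch, heads, head_dim, state)
    h = h * dA + dBx                                         # recurrent state update
    y = mx.sum(h * mx.expand_dims(C, -2), axis=-1)           # readout against C
    y = y + mx.expand_dims(mx.expand_dims(D, 0), -1) * x     # skip connection via D
    return h, y
```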

View File

@@ -350,28 +350,11 @@ class MambaCache:
return [self.cache[1]]
class Mamba2Cache:
def __init__(self):
self.conv_states = [None] # Initialize as None, will be set on first use
self.ssm_states = [None] # Initialize as None, will be set on first use
class Mamba2Cache(_BaseCache):
def __init__(
self,
batch_size,
conv_kernel
):
self.conv_kernel: mx.array = conv_kernel
self.conv_states: mx.array = [None]
self.ssm_states = [None]
self.seqlen_offset = 0
def reset(self):
self.conv_states = None
self.ssm_state = None
def update(self, layer_idx: int, new_conv_state: mx.array, cache_position: mx.array) -> mx.array:
conv_state = self.conv_states[layer_idx]
cache_position = cache_position.clamp(0, self.conv_kernel - 1)
conv_state = conv_state.roll(shifts=-1, dims=-1)
conv_state[:, :, cache_position] = new_conv_state
self.conv_states[layer_idx].zero_()
self.conv_states[layer_idx] += conv_state
return self.conv_states[layer_idx]
@property
def state(self):
return [self.conv_states[0], self.ssm_states[0]]
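A quick sketch of how a layer is expected to lazily seed and reuse the simplified Mamba2Cache above; the helper name and all shape values are assumptions for illustration, not part of this commit.

```python
import mlx.core as mx

def ensure_cache(cache, batch_size, conv_kernel, conv_dim, num_heads, head_dim, state_size):
    # Allocate both per-layer states on the first call, once the shapes are known.
    if cache.conv_states[0] is None:
        cache.conv_states[0] = mx.zeros((batch_size, conv_kernel - 1, conv_dim))
    if cache.ssm_states[0] is None:
        cache.ssm_states[0] = mx.zeros((batch_size, num_heads, head_dim, state_size))
    return cache

# Usage with made-up sizes: one cache object per layer, reused across decode steps.
cache = ensure_cache(Mamba2Cache(), batch_size=1, conv_kernel=4, conv_dim=1280,
                     num_heads=32, head_dim=32, state_size=128)
```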

View File

@@ -1,11 +1,11 @@
import math
from dataclasses import dataclass, field
from typing import Tuple, Union, Optional
from typing import Tuple, Union
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .cache import MambaCache
from .cache import Mamba2Cache
@dataclass
class ModelArgs(BaseModelArgs):
@@ -61,9 +61,8 @@ class MambaRMSNormGated(nn.Module):
def silu(x):
return x * mx.sigmoid(x)
def ssd(x, A, B, C, chunk_size):
# Currently unused
# Replace einsum operations with explicit reshape and matrix multiply
batch, seqlen, nheads, dim = x.shape
B = mx.expand_dims(B, axis=2)
C = mx.expand_dims(C, axis=2)
@@ -92,30 +91,91 @@ def ssd(x, A, B, C, chunk_size):
return mx.concatenate(outputs, axis=1), state
class DepthWiseConv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.padding = padding
self.groups = groups if groups is not None else in_channels
assert in_channels == out_channels, "In and out channels must be the same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
self.weight = mx.random.normal((in_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape
K = self.kernel_size
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
if cache is not None:
if cache.conv_states[0] is None:
cache.conv_states[0] = mx.zeros((B, K-1, C))
x = mx.concatenate([cache.conv_states[0], x], axis=1)
outputs = []
for c in range(C):
# Take channel c as (B, L) and reshape to (B, 1, L) for conv_general
x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1)
# Reshape this channel's kernel to (1, 1, K) for conv_general
w_c = self.weight[c]
if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0)
elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
y_c = mx.conv_general(
x_c,
w_c,
stride=1,
padding=0
)
if self.bias is not None:
y_c = y_c + self.bias[c]
y_c = mx.squeeze(y_c, axis=1)
outputs.append(y_c)
# Stack the per-channel outputs back along the channel axis
y = mx.stack(outputs, axis=-1)
# Store the trailing K-1 input positions so the next call has its causal context
if cache is not None:
cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
return y
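Single-token usage of the DepthWiseConv1d above: the first call seeds the cache with zeros, and each call keeps the last kernel_size - 1 inputs so the next step sees a full causal window. The sizes below are made up for illustration, and the `[0]` cache index follows the single-layer convention of Mamba2Cache.

```python
import mlx.core as mx

conv = DepthWiseConv1d(in_channels=8, out_channels=8, kernel_size=4)
cache = Mamba2Cache()
x_t = mx.random.normal((1, 1, 8))   # one token: (batch, length=1, channels)
y_t = conv(x_t, cache=cache)        # (1, 1, 8); the first call zero-pads via the cache
# cache.conv_states[0] now holds the last kernel_size - 1 = 3 positions for the next token
```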
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.args = args
self.chunk_size = args.chunk_size
d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
self.conv_dim = args.intermediate_size + 2 * args.state_size
# Replace DepthWiseConv1d with grouped nn.Conv1d
self.conv1d = nn.Conv1d(
in_channels=self.conv_dim,
out_channels=self.conv_dim,
conv_dim = args.intermediate_size + 2 * args.state_size
self.conv1d = DepthWiseConv1d(
in_channels=conv_dim,
out_channels=conv_dim,
kernel_size=args.conv_kernel,
groups=self.conv_dim, # Makes it depthwise
groups=conv_dim,
bias=args.use_conv_bias,
padding=0 # We'll handle padding via cache
padding=args.conv_kernel - 1
)
self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
# Initialize parameters
self.dt_bias = mx.ones(args.num_heads)
A = mx.arange(1, args.num_heads + 1)
self.A_log = mx.log(A)
self.D = mx.ones(args.num_heads)
self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
@@ -124,24 +184,18 @@ class Mamba2Block(nn.Module):
layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, u: mx.array, cache: Optional[MambaCache] = None):
batch_size, seq_len, _ = u.shape
pad_size = self.chunk_size - (seq_len % self.chunk_size)
def __call__(self, u: mx.array, cache=None):
# Expect input to be shape [batch_size, 1, dim]
batch_size, seq_len, dimension = u.shape
assert seq_len == 1, "Input should be a single token"
# Initialize cache if needed
if cache is None:
cache = MambaCache()
if cache.conv_states[0] is None:
conv_dim = self.args.intermediate_size + 2 * self.args.state_size
cache.conv_states[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
# Initialize states if needed
if cache[0] is None: # conv state
cache[0] = mx.zeros((
batch_size,
self.args.conv_kernel - 1,
self.conv_dim
))
if cache[1] is None: # ssm state
cache[1] = mx.zeros((
if cache.ssm_states[0] is None:
cache.ssm_states[0] = mx.zeros((
batch_size,
self.args.num_heads,
self.args.head_dim,
@@ -152,90 +206,67 @@ class Mamba2Block(nn.Module):
zxbcdt = self.in_proj(u)
# Split projections
n_heads = self.args.num_heads
z = zxbcdt[:, :, :self.args.intermediate_size]
xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
dt = zxbcdt[:, :, -(self.args.num_heads):]
dt = zxbcdt[:, :, -(n_heads):]
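# For concreteness (sizes below are illustrative, not this repo's config):
# with intermediate_size=1024, state_size=128 and num_heads=32,
# d_in_proj = 2*1024 + 2*128 + 32 = 2336, so z takes the first 1024 columns,
# xBC the next 1024 + 2*128 = 1280 (x, B and C concatenated), and dt the last 32.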
# Process delta time
dt = mx.reshape(dt, (batch_size, seq_len, self.args.num_heads))
dt = mx.squeeze(dt, axis=0)
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
# Time steps
dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
dt = mx.maximum(dt, self.args.time_step_floor)
# Handle convolution caching and padding
conv_state = cache[0]
if conv_state is not None:
xBC = mx.concatenate([conv_state, xBC], axis=1)
# Prepare input for conv1d: [B, C, L]
xBC = mx.transpose(xBC, [0, 2, 1])
# Apply convolution
xBC = self.conv1d(xBC)
# Update cache state
cache[0] = mx.transpose(xBC, [0, 2, 1])[:, -self.args.conv_kernel+1:, :]
# Return to [B, L, C] format
xBC = mx.transpose(xBC, [0, 2, 1])
# Convolution
xBC = self.conv1d(xBC, cache=cache)
xBC = silu(xBC)
# Split conv output
# Split states
x = xBC[:, :, :self.args.intermediate_size]
B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
C = xBC[:, :, -self.args.state_size:]
# Reshape for SSM
x = mx.reshape(x, (batch_size, seq_len, self.args.num_heads, self.args.head_dim))
x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.squeeze(x, axis=1)
B = mx.reshape(B, (batch_size, 1, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
C = mx.expand_dims(C, axis=3)
B = mx.reshape(B, (batch_size, seq_len, self.args.state_size))
B = mx.broadcast_to(B, (batch_size, self.args.num_heads, self.args.state_size))
C = mx.reshape(C, (batch_size, seq_len, self.args.state_size))
C = mx.broadcast_to(C, (batch_size, self.args.num_heads, self.args.state_size))
# SSM state update
ssm_state = cache[1]
# SSM updates
A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
x = mx.expand_dims(x, axis=-1)
dBx = mx.matmul(x, mx.expand_dims(B, axis=-2))
# Update state
x = mx.expand_dims(x, axis=3)
dBx = mx.matmul(x, B)
cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
new_ssm_state = ssm_state * mx.expand_dims(dA, -1) + dBx
cache[1] = new_ssm_state
# Output computation
y = mx.matmul(new_ssm_state, mx.expand_dims(C, axis=-1))
# Compute output
y = mx.matmul(cache.ssm_states[0], C)
y = mx.squeeze(y, axis=-1)
if pad_size > 0:
y = y[:, :seq_len, :, :]
# Final reshape and projections
y = mx.reshape(y, (batch_size, seq_len, -1))
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = self.norm(y + z)
return self.out_proj(y)
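Since the commit message notes that multi-token inputs are not handled yet, one possible stopgap is to drive the single-token block position by position during prefill. This is a sketch only, and `prefill` is a hypothetical helper, not something this commit implements.

```python
import mlx.core as mx

def prefill(block, u, cache):
    # u: (batch, seq_len, hidden_size); feed the single-token-only block one
    # position at a time, reusing its per-token conv/SSM cache path.
    outputs = []
    for t in range(u.shape[1]):
        outputs.append(block(u[:, t:t + 1, :], cache=cache))
    return mx.concatenate(outputs, axis=1)
```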
class ResidualBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.residual_in_fp32 = args.residual_in_fp32
self.mixer = Mamba2Block(args)
self.norm = nn.RMSNorm(args.hidden_size)
def __call__(self, x: mx.array, cache):
if self.residual_in_fp32:
x = x.astype(mx.float32)
return self.mixer(self.norm(x), cache) + x
normed = self.norm(x)
output = self.mixer(normed, cache)
return output + x
class Mamba2(nn.Module):
def __init__(self, args: ModelArgs):
@@ -249,9 +280,11 @@ class Mamba2(nn.Module):
x = self.embeddings(x)
if cache is None:
cache = [None] * len(self.layers)
hidden = x
for layer, c in zip(self.layers, cache):
x = layer(x, c)
return self.norm_f(x)
hidden = layer(hidden, c)
return self.norm_f(hidden)
class Model(nn.Module):
@@ -259,32 +292,23 @@ class Model(nn.Module):
super().__init__()
self.args = args
self.model_type = args.model_type
self.backbone = Mamba2(args)
if not args.tie_word_embeddings:
self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
def __call__(self, inputs: mx.array, cache=None):
B, T = inputs.shape
x = self.backbone(inputs, cache)
hidden = self.backbone(inputs, cache)
if self.args.tie_word_embeddings:
logits = self.backbone.embeddings.as_linear(x)
logits = self.backbone.embeddings.as_linear(hidden)
else:
logits = self.lm_head(x)
logits = self.lm_head(hidden)
return logits
def make_cache(self, batch_size=1):
return [MambaCache() for _ in range(len(self.backbone.layers))]
def sanitize(self, weights):
for k, v in weights.items():
if "conv1d.weight" in k and v.shape[-1] != 1:
weights[k] = v.moveaxis(2, 1)
return weights
def make_cache(self):
return [Mamba2Cache() for _ in range(len(self.layers))]
@property
def layers(self):