fixed inference slowness but it can't handle multi-token inputs and is generating gibberish

Goekdeniz-Guelmez 2024-11-10 16:35:07 +01:00
parent 800b60239c
commit 3a499f9735
2 changed files with 158 additions and 151 deletions
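
Note on the limitation in the title: the rewritten Mamba2Block asserts seq_len == 1, so the prompt now has to be pushed through the model one token at a time with a per-layer cache. A minimal decode loop under that constraint might look like the sketch below (model is an instance of the Model class from this commit; prompt_ids is a hypothetical list of token ids, not part of the commit):

import mlx.core as mx

cache = model.make_cache()                    # one Mamba2Cache per layer
logits = None
for tok in prompt_ids:                        # feed the prompt one token at a time
    inp = mx.array([[tok]])                   # shape [batch_size=1, seq_len=1]
    logits = model(inp, cache)                # conv/ssm states are updated in place
next_token = mx.argmax(logits[:, -1, :], axis=-1)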

View File

@@ -350,28 +350,11 @@ class MambaCache:
         return [self.cache[1]]
 
-class Mamba2Cache(_BaseCache):
-    def __init__(
-        self,
-        batch_size,
-        conv_kernel
-    ):
-        self.conv_kernel: mx.array = conv_kernel
-        self.conv_states: mx.array = [None]
-        self.ssm_states = [None]
-        self.seqlen_offset = 0
-
-    def reset(self):
-        self.conv_states = None
-        self.ssm_state = None
-
-    def update(self, layer_idx: int, new_conv_state: mx.array, cache_position: mx.array) -> mx.array:
-        conv_state = self.conv_states[layer_idx]
-        cache_position = cache_position.clamp(0, self.conv_kernel - 1)
-        conv_state = conv_state.roll(shifts=-1, dims=-1)
-        conv_state[:, :, cache_position] = new_conv_state
-        self.conv_states[layer_idx].zero_()
-        self.conv_states[layer_idx] += conv_state
-        return self.conv_states[layer_idx]
+class Mamba2Cache:
+    def __init__(self):
+        self.conv_states = [None]  # Initialize as None, will be set on first use
+        self.ssm_states = [None]  # Initialize as None, will be set on first use
+
+    @property
+    def state(self):
+        return [self.conv_states[0], self.ssm_states[0]]
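
For orientation (illustrative, not part of the diff): the new Mamba2Cache is deliberately lazy. Both buffers start as [None] and are allocated by the model code on the first forward pass, so their shapes depend on the batch size actually seen. Roughly, after one step with batch size B (shape names follow the ModelArgs fields in the model file below):

cache = Mamba2Cache()                          # Mamba2Cache as defined in the hunk above
assert cache.conv_states[0] is None and cache.ssm_states[0] is None

# After the first Mamba2Block.__call__ with batch size B:
#   cache.conv_states[0] -> (B, conv_kernel - 1, intermediate_size + 2 * state_size)
#   cache.ssm_states[0]  -> (B, num_heads, head_dim, state_size)
conv_state, ssm_state = cache.state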

View File

@@ -1,11 +1,11 @@
 import math
 from dataclasses import dataclass, field
-from typing import Tuple, Union, Optional
+from typing import Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .cache import MambaCache
+from .cache import Mamba2Cache
 
 @dataclass
 class ModelArgs(BaseModelArgs):
@@ -56,186 +56,217 @@ class MambaRMSNormGated(nn.Module):
         variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True)
         hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states
 
 def silu(x):
     return x * mx.sigmoid(x)
 
 def ssd(x, A, B, C, chunk_size):
-    # Not getting used
+    # Replace einsum operations with explicit reshape and matrix multiply
     batch, seqlen, nheads, dim = x.shape
     B = mx.expand_dims(B, axis=2)
     C = mx.expand_dims(C, axis=2)
 
     state = mx.zeros((batch, nheads, dim, B.shape[-1]))
     outputs = []
 
     for i in range(0, seqlen, chunk_size):
         chunk = slice(i, min(i + chunk_size, seqlen))
         dA = mx.exp(mx.expand_dims(A[chunk], axis=0))
 
         # Replace einsum with explicit operations
         x_chunk = x[:, chunk]  # [batch, chunk_size, nheads, dim]
         x_chunk = mx.transpose(x_chunk, [0, 2, 3, 1])  # [batch, nheads, dim, chunk_size]
         B_chunk = B[:, chunk]  # [batch, chunk_size, state_size]
         dBx = mx.matmul(x_chunk, B_chunk)  # [batch, nheads, dim, state_size]
 
         state = state * mx.expand_dims(dA, axis=-1) + dBx
 
         # Replace einsum with explicit operations
         C_chunk = C[:, chunk]  # [batch, chunk_size, state_size]
         y = mx.matmul(state, mx.transpose(C_chunk, [0, 2, 1]))  # [batch, nheads, dim, chunk_size]
         y = mx.transpose(y, [0, 3, 1, 2])  # [batch, chunk_size, nheads, dim]
         outputs.append(y)
 
     return mx.concatenate(outputs, axis=1), state
 
+class DepthWiseConv1d(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.kernel_size = kernel_size
+        self.padding = padding
+        self.groups = groups if groups is not None else in_channels
+
+        assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
+        assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
+
+        self.weight = mx.random.normal((in_channels, 1, kernel_size))
+        self.bias = mx.zeros((out_channels,)) if bias else None
+
+    def __call__(self, x: mx.array, cache=None) -> mx.array:
+        B, L, C = x.shape
+        K = self.kernel_size
+
+        assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
+
+        if cache is not None:
+            if isinstance(cache.conv_states[0], type(None)):
+                cache.conv_states[0] = mx.zeros((B, K-1, C))
+            x = mx.concatenate([cache.conv_states[0], x], axis=1)
+
+        outputs = []
+        for c in range(C):
+            # Input prep debug
+            x_c = x[:, :, c]
+            x_c = mx.expand_dims(x_c, axis=1)
+
+            # Weight prep debug
+            w_c = self.weight[c]
+            if w_c.ndim == 2:
+                w_c = mx.expand_dims(w_c, axis=0)
+            elif w_c.ndim == 1:
+                w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
+
+            y_c = mx.conv_general(
+                x_c,
+                w_c,
+                stride=1,
+                padding=0
+            )
+            if self.bias is not None:
+                y_c = y_c + self.bias[c]
+
+            y_c = mx.squeeze(y_c, axis=1)
+            outputs.append(y_c)
+
+        # Output statistics
+        y = mx.stack(outputs, axis=-1)
+
+        # Cache update debug
+        if cache is not None:
+            cache.conv_states[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
+
+        return y
+
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
-        self.chunk_size = args.chunk_size
 
         d_in_proj = 2 * args.intermediate_size + 2 * args.state_size + args.num_heads
         self.in_proj = nn.Linear(args.hidden_size, d_in_proj, bias=args.use_bias)
 
-        self.conv_dim = args.intermediate_size + 2 * args.state_size
-
-        # Replace DepthWiseConv1d with grouped nn.Conv1d
-        self.conv1d = nn.Conv1d(
-            in_channels=self.conv_dim,
-            out_channels=self.conv_dim,
+        conv_dim = args.intermediate_size + 2 * args.state_size
+        self.conv1d = DepthWiseConv1d(
+            in_channels=conv_dim,
+            out_channels=conv_dim,
             kernel_size=args.conv_kernel,
-            groups=self.conv_dim,  # Makes it depthwise
+            groups=conv_dim,
             bias=args.use_conv_bias,
-            padding=0  # We'll handle padding via cache
+            padding=args.conv_kernel - 1
         )
 
-        self.dt_bias = mx.random.normal((args.num_heads,)) * args.initializer_range
-        self.A_log = mx.random.normal((args.num_heads,)) * args.initializer_range
-        self.D = mx.random.normal((args.num_heads,)) * args.initializer_range
+        # Initialize parameters
+        self.dt_bias = mx.ones(args.num_heads)
+        A = mx.arange(1, args.num_heads + 1)
+        self.A_log = mx.log(A)
+        self.D = mx.ones(args.num_heads)
 
         self.norm = MambaRMSNormGated(args.intermediate_size, eps=args.layer_norm_epsilon)
         self.out_proj = nn.Linear(args.intermediate_size, args.hidden_size, bias=args.use_bias)
 
         if args.rescale_prenorm_residual:
             layer_scale = math.sqrt(1.0 / args.num_hidden_layers)
             self.out_proj.weight = self.out_proj.weight * layer_scale
 
-    def __call__(self, u: mx.array, cache: Optional[MambaCache] = None):
-        batch_size, seq_len, _ = u.shape
-        pad_size = self.chunk_size - (seq_len % self.chunk_size)
+    def __call__(self, u: mx.array, cache=None):
+        # Expect input to be shape [batch_size, 1, dim]
+        batch_size, seq_len, dimension = u.shape
+        assert seq_len == 1, "Input should be a single token"
 
         # Initialize cache if needed
-        if cache is None:
-            cache = MambaCache()
-
-        # Initialize states if needed
-        if cache[0] is None:  # conv state
-            cache[0] = mx.zeros((
-                batch_size,
-                self.args.conv_kernel - 1,
-                self.conv_dim
-            ))
-        if cache[1] is None:  # ssm state
-            cache[1] = mx.zeros((
+        if cache.conv_states[0] is None:
+            conv_dim = self.args.intermediate_size + 2 * self.args.state_size
+            cache.conv_states[0] = mx.zeros((batch_size, self.args.conv_kernel - 1, conv_dim))
+
+        if cache.ssm_states[0] is None:
+            cache.ssm_states[0] = mx.zeros((
                 batch_size,
                 self.args.num_heads,
                 self.args.head_dim,
                 self.args.state_size
             ))
 
         # Project input
         zxbcdt = self.in_proj(u)
 
         # Split projections
+        n_heads = self.args.num_heads
         z = zxbcdt[:, :, :self.args.intermediate_size]
         xBC = zxbcdt[:, :, self.args.intermediate_size:self.args.intermediate_size + 2*self.args.state_size + self.args.intermediate_size]
-        dt = zxbcdt[:, :, -(self.args.num_heads):]
+        dt = zxbcdt[:, :, -(n_heads):]
 
-        # Process delta time
-        dt = mx.reshape(dt, (batch_size, seq_len, self.args.num_heads))
-        dt = mx.squeeze(dt, axis=0)
-        dt = mx.clip(
-            nn.softplus(dt + self.dt_bias),
-            self.args.time_step_min,
-            self.args.time_step_max
-        )
+        # Time steps
+        dt = mx.reshape(dt, (batch_size, n_heads))
+        dt = mx.clip(nn.softplus(dt + self.dt_bias), self.args.time_step_min, self.args.time_step_max)
         dt = mx.maximum(dt, self.args.time_step_floor)
 
-        # Handle convolution caching and padding
-        conv_state = cache[0]
-        if conv_state is not None:
-            xBC = mx.concatenate([conv_state, xBC], axis=1)
-
-        # Prepare input for conv1d: [B, C, L]
-        xBC = mx.transpose(xBC, [0, 2, 1])
-
-        # Apply convolution
-        xBC = self.conv1d(xBC)
-
-        # Update cache state
-        cache[0] = mx.transpose(xBC, [0, 2, 1])[:, -self.args.conv_kernel+1:, :]
-
-        # Return to [B, L, C] format
-        xBC = mx.transpose(xBC, [0, 2, 1])
+        # Convolution
+        xBC = self.conv1d(xBC, cache=cache)
         xBC = silu(xBC)
 
-        # Split conv output
+        # Split states
         x = xBC[:, :, :self.args.intermediate_size]
         B = xBC[:, :, self.args.intermediate_size:self.args.intermediate_size + self.args.state_size]
         C = xBC[:, :, -self.args.state_size:]
 
         # Reshape for SSM
-        x = mx.reshape(x, (batch_size, seq_len, self.args.num_heads, self.args.head_dim))
-        B = mx.reshape(B, (batch_size, seq_len, self.args.state_size))
-        B = mx.broadcast_to(B, (batch_size, self.args.num_heads, self.args.state_size))
-        C = mx.reshape(C, (batch_size, seq_len, self.args.state_size))
-        C = mx.broadcast_to(C, (batch_size, self.args.num_heads, self.args.state_size))
+        x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
+        x = mx.squeeze(x, axis=1)
+        B = mx.reshape(B, (batch_size, 1, self.args.state_size))
+        B = mx.broadcast_to(B, (batch_size, n_heads, self.args.state_size))
+        B = mx.expand_dims(B, axis=2)
+        C = mx.reshape(C, (batch_size, 1, self.args.state_size))
+        C = mx.broadcast_to(C, (batch_size, n_heads, self.args.state_size))
+        C = mx.expand_dims(C, axis=3)
 
-        # SSM state update
-        ssm_state = cache[1]
+        # SSM updates
         A = -mx.exp(self.A_log)
         dA = mx.exp(dt * mx.expand_dims(A, 0))
-        x = mx.expand_dims(x, axis=-1)
-        dBx = mx.matmul(x, mx.expand_dims(B, axis=-2))
-        new_ssm_state = ssm_state * mx.expand_dims(dA, -1) + dBx
-        cache[1] = new_ssm_state
-
-        # Output computation
-        y = mx.matmul(new_ssm_state, mx.expand_dims(C, axis=-1))
-        y = mx.squeeze(y, axis=-1)
-        if pad_size > 0:
-            y = y[:, :seq_len, :, :]
-
-        # Final reshape and projections
-        y = mx.reshape(y, (batch_size, seq_len, -1))
-        y = self.norm(y + z)
-        return self.out_proj(y)
+        dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
+
+        # Update state
+        x = mx.expand_dims(x, axis=3)
+        dBx = mx.matmul(x, B)
+        cache.ssm_states[0] = cache.ssm_states[0] * dA + dBx
+
+        # Compute output
+        y = mx.matmul(cache.ssm_states[0], C)
+        y = mx.squeeze(y, axis=-1)
+        y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
+        y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
+        y = self.norm(y + z)
+        return self.out_proj(y)
 
 class ResidualBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.residual_in_fp32 = args.residual_in_fp32
         self.mixer = Mamba2Block(args)
         self.norm = nn.RMSNorm(args.hidden_size)
 
     def __call__(self, x: mx.array, cache):
         if self.residual_in_fp32:
             x = x.astype(mx.float32)
-        return self.mixer(self.norm(x), cache) + x
+        normed = self.norm(x)
+        output = self.mixer(normed, cache)
+        return output + x
 
 class Mamba2(nn.Module):
     def __init__(self, args: ModelArgs):
@@ -249,9 +280,11 @@ class Mamba2(nn.Module):
         x = self.embeddings(x)
         if cache is None:
             cache = [None] * len(self.layers)
+        hidden = x
         for layer, c in zip(self.layers, cache):
-            x = layer(x, c)
-        return self.norm_f(x)
+            hidden = layer(hidden, c)
+        return self.norm_f(hidden)
 
 class Model(nn.Module):
@@ -259,33 +292,24 @@ class Model(nn.Module):
         super().__init__()
         self.args = args
         self.model_type = args.model_type
         self.backbone = Mamba2(args)
         if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
 
     def __call__(self, inputs: mx.array, cache=None):
-        B, T = inputs.shape
-        x = self.backbone(inputs, cache)
+        hidden = self.backbone(inputs, cache)
         if self.args.tie_word_embeddings:
-            logits = self.backbone.embeddings.as_linear(x)
+            logits = self.backbone.embeddings.as_linear(hidden)
         else:
-            logits = self.lm_head(x)
+            logits = self.lm_head(hidden)
         return logits
 
-    def make_cache(self, batch_size=1):
-        return [MambaCache() for _ in range(len(self.backbone.layers))]
-
-    def sanitize(self, weights):
-        for k, v in weights.items():
-            if "conv1d.weight" in k and v.shape[-1] != 1:
-                weights[k] = v.moveaxis(2, 1)
-        return weights
+    def make_cache(self):
+        return [Mamba2Cache() for _ in range(len(self.layers))]
 
     @property
     def layers(self):
         return self.backbone.layers
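
As a reading aid for the new single-token SSM update above (illustrative only, not part of the commit): per head, the recurrence reduces to one decayed state update followed by a readout against C. With toy dimensions the shapes line up like this:

import mlx.core as mx

batch, n_heads, head_dim, state_size = 1, 2, 4, 8

x  = mx.random.normal((batch, n_heads, head_dim, 1))    # current token's features per head
B  = mx.random.normal((batch, n_heads, 1, state_size))  # input projection B_t
C  = mx.random.normal((batch, n_heads, state_size, 1))  # readout projection C_t
dA = mx.random.uniform(shape=(batch, n_heads, 1, 1))    # stand-in for exp(dt * A)

state = mx.zeros((batch, n_heads, head_dim, state_size))
state = state * dA + mx.matmul(x, B)                    # h_t = dA * h_{t-1} + x_t B_t
y = mx.squeeze(mx.matmul(state, C), axis=-1)            # y_t = h_t C_t -> (batch, n_heads, head_dim)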