inference works but is hella slow

This commit is contained in:
Goekdeniz-Guelmez 2024-10-22 23:06:06 +02:00
parent 9ab581d678
commit a677638c4b
2 changed files with 61 additions and 135 deletions

View File

@ -342,7 +342,7 @@ class MambaCache(_BaseCache):
class Mamba2Cache(_BaseCache): class Mamba2Cache(_BaseCache):
conv_states: Optional[mx.array] = None conv_states: Optional[mx.array] = None
ssm_states: Optional[mx.array] = None ssm_state: Optional[mx.array] = None
def __getitem__(self, idx: int) -> Optional[mx.array]: def __getitem__(self, idx: int) -> Optional[mx.array]:
if idx == 0: if idx == 0:

View File

@ -103,89 +103,55 @@ class DepthWiseConv1d(nn.Module):
assert in_channels == out_channels, "In and out channels must be same for depthwise convolution" assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution" assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
# Weight shape: (channels, 1, kernel_size) to match pretrained weights
self.weight = mx.random.normal((in_channels, 1, kernel_size)) self.weight = mx.random.normal((in_channels, 1, kernel_size))
self.bias = mx.zeros((out_channels,)) if bias else None self.bias = mx.zeros((out_channels,)) if bias else None
def __call__(self, x: mx.array, cache=None, cache_idx: int = 0) -> mx.array: def __call__(self, x: mx.array, cache=None) -> mx.array:
B, L, C = x.shape B, L, C = x.shape
K = self.kernel_size K = self.kernel_size
# Validate input dimensions
assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}" assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
# Handle padding and caching if cache is not None and 'conv_states' in cache:
if cache is not None: conv_states = cache['conv_states']
conv_states = cache[cache_idx]
if conv_states is not None: if conv_states is not None:
# Validate cache shape
assert conv_states.shape[0] == B, "Cache batch size mismatch" assert conv_states.shape[0] == B, "Cache batch size mismatch"
assert conv_states.shape[2] == C, "Cache channel count mismatch" assert conv_states.shape[2] == C, "Cache channel count mismatch"
x = mx.concatenate([conv_states, x], axis=1) x = mx.concatenate([conv_states, x], axis=1)
L = x.shape[1]
else:
# Add left padding of size (kernel_size - 1)
pad_left = K - 1
x = mx.pad(x, [(0, 0), (pad_left, 0), (0, 0)])
L = x.shape[1]
# Pre-allocate output array if possible
outputs = []
# Process each channel independently
for c in range(C):
# Extract and prepare channel data
x_c = x[:, :, c] # Shape: [B, L]
x_c = mx.expand_dims(x_c, axis=1) # Shape: [B, 1, L]
# Prepare filter weights # Process each channel independently
w_c = self.weight[c] # Get channel weights outputs = []
# Ensure filter is 3D: [depth(1), in_channels(1), kernel_size] for c in range(C):
x_c = x[:, :, c]
x_c = mx.expand_dims(x_c, axis=1)
w_c = self.weight[c]
if w_c.ndim == 2: if w_c.ndim == 2:
w_c = mx.expand_dims(w_c, axis=0) w_c = mx.expand_dims(w_c, axis=0)
elif w_c.ndim == 1: elif w_c.ndim == 1:
w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0) w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
# Handle inference mode (single token) # Apply convolution
if L < K: y_c = mx.conv_general(
pad_size = K - L x_c,
x_c = mx.pad(x_c, [(0, 0), (0, 0), (pad_size, 0)]) w_c,
stride=1,
padding=0
)
# Apply 1D convolution if self.bias is not None:
try: y_c = y_c + self.bias[c]
y_c = mx.conv_general(
x_c, outputs.append(mx.squeeze(y_c, axis=1))
w_c,
stride=1,
padding=0 # Padding already handled
)
if self.bias is not None:
y_c = y_c + self.bias[c]
# Remove singleton dimension and add to outputs
outputs.append(mx.squeeze(y_c, axis=1))
except Exception as e:
raise RuntimeError(f"Convolution failed for channel {c}. Shapes: input={x_c.shape}, weight={w_c.shape}") from e
# Stack channel outputs along last dimension y = mx.stack(outputs, axis=-1)
y = mx.stack(outputs, axis=-1) # Shape: [B, L', C]
# Update cache if needed # Update cache
if cache is not None: if cache is not None:
# Store last (kernel_size - 1) tokens or entire input if shorter cache['conv_states'] = x[:, -K+1:, :] if x.shape[1] >= K else x
new_cache = x[:, -(K-1):, :] if L >= K else x
cache[cache_idx] = new_cache
if new_cache.shape != cache[cache_idx].shape:
cache[cache_idx] = new_cache
print(f"Cache updated at index {cache_idx}")
else:
print(f"Skipping cache update at index {cache_idx}, shapes are identical.")
return y return y
class Mamba2Block(nn.Module): class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs): def __init__(self, args: ModelArgs):
@ -217,34 +183,22 @@ class Mamba2Block(nn.Module):
self.out_proj.weight = self.out_proj.weight * layer_scale self.out_proj.weight = self.out_proj.weight * layer_scale
def __call__(self, x: mx.array, cache=None): def __call__(self, x: mx.array, cache=None):
# if cache is not None and self.args.use_cache:
if cache is not None: if cache is not None:
return self.step(x, cache) return self.step(x, cache)
# Calculate sizes # Regular forward pass code remains the same...
d_model = self.args.intermediate_size d_model = self.args.intermediate_size
d_state = self.args.state_size d_state = self.args.state_size
n_heads = self.args.num_heads n_heads = self.args.num_heads
# Compute A
A = -mx.exp(self.A_log) A = -mx.exp(self.A_log)
# Project input
zxbcdt = self.in_proj(x) zxbcdt = self.in_proj(x)
# Correct splits for z, xBC, dt splits = [d_model, d_model + 2 * d_state, n_heads]
splits = [
d_model, # z
d_model + 2 * d_state, # xBC (delta, B, C concatenated)
n_heads # dt
]
# Split using cumulative indices
z = zxbcdt[:, :, :splits[0]] z = zxbcdt[:, :, :splits[0]]
xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]] xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
dt = zxbcdt[:, :, -splits[2]:] dt = zxbcdt[:, :, -splits[2]:]
# Process dt
dt = mx.clip( dt = mx.clip(
nn.softplus(dt + self.dt_bias), nn.softplus(dt + self.dt_bias),
self.args.time_step_min, self.args.time_step_min,
@ -252,46 +206,26 @@ class Mamba2Block(nn.Module):
) )
dt = mx.maximum(dt, self.args.time_step_floor) dt = mx.maximum(dt, self.args.time_step_floor)
# Process convolution
xBC = silu(self.conv1d(xBC)) xBC = silu(self.conv1d(xBC))
# Split convolved xBC into x, B, C
x = xBC[:, :, :d_model] x = xBC[:, :, :d_model]
B = xBC[:, :, d_model:d_model + d_state] B = xBC[:, :, d_model:d_model + d_state]
C = xBC[:, :, -d_state:] C = xBC[:, :, -d_state:]
# Reshape for SSM computation
b, l, hp = x.shape b, l, hp = x.shape
h = self.args.num_heads h = self.args.num_heads
p = hp // h p = hp // h
x = mx.reshape(x, (b, l, h, p)) x = mx.reshape(x, (b, l, h, p))
# Compute SSM y, ssm_state = ssd(x * mx.expand_dims(dt, -1), A * dt, B, C, self.args.chunk_size)
y, ssm_state = ssd(
x * mx.expand_dims(dt, -1),
A * dt,
B,
C,
self.args.chunk_size
)
# Add skip connection
y = y + x * mx.expand_dims(self.D, -1) y = y + x * mx.expand_dims(self.D, -1)
# Reshape back
y = mx.reshape(y, (b, l, h * p)) y = mx.reshape(y, (b, l, h * p))
# Apply norm and projection
y = self.norm(y + z) y = self.norm(y + z)
y = self.out_proj(y) y = self.out_proj(y)
# Update cache if needed
if cache is not None and self.args.use_cache:
cache[1] = ssm_state
# Cast if needed
if self.args.residual_in_fp32: if self.args.residual_in_fp32:
y.astype(mx.float32) y = y.astype(mx.float32)
return y return y
@ -300,9 +234,17 @@ class Mamba2Block(nn.Module):
seq_len = u.shape[1] seq_len = u.shape[1]
outputs = [] outputs = []
# Initialize SSM state if needed # Initialize cache if needed
if cache[1] is None: if cache.conv_states is None:
cache[1] = mx.zeros(( conv_dim = self.args.intermediate_size + 2 * self.args.state_size
cache.conv_states = mx.zeros((
batch_size,
self.args.conv_kernel - 1,
conv_dim
))
if cache.ssm_state is None:
cache.ssm_state = mx.zeros((
batch_size, batch_size,
self.args.num_heads, self.args.num_heads,
self.args.head_dim, self.args.head_dim,
@ -310,26 +252,17 @@ class Mamba2Block(nn.Module):
)) ))
for pos in range(seq_len): for pos in range(seq_len):
# Getting stuck here in last position, also cache from pos 0 is the same.
# Get single token
u_t = u[:, pos:pos+1, :] u_t = u[:, pos:pos+1, :]
# Project input
zxbcdt = self.in_proj(u_t) zxbcdt = self.in_proj(u_t)
# Calculate sizes
d_model = self.args.intermediate_size d_model = self.args.intermediate_size
d_state = self.args.state_size d_state = self.args.state_size
n_heads = self.args.num_heads n_heads = self.args.num_heads
d_head = self.args.head_dim
# Split projected input
# conv_dim = d_model + 2 * d_state (this should match self.conv1d.in_channels)
z = zxbcdt[:, :, :d_model] z = zxbcdt[:, :, :d_model]
xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model] # Include the full conv dimension xBC = zxbcdt[:, :, d_model:d_model + 2*d_state + d_model]
dt = zxbcdt[:, :, -(n_heads):] dt = zxbcdt[:, :, -(n_heads):]
# Process dt
dt = mx.reshape(dt, (batch_size, n_heads)) dt = mx.reshape(dt, (batch_size, n_heads))
dt = mx.clip( dt = mx.clip(
nn.softplus(dt + self.dt_bias), nn.softplus(dt + self.dt_bias),
@ -338,49 +271,43 @@ class Mamba2Block(nn.Module):
) )
dt = mx.maximum(dt, self.args.time_step_floor) dt = mx.maximum(dt, self.args.time_step_floor)
# Process convolution with correct dimensions # Create a temporary cache dictionary for the convolution
xBC = self.conv1d(xBC, cache=cache, cache_idx=0) conv_cache = {'conv_states': cache.conv_states}
xBC = self.conv1d(xBC, cache=conv_cache)
cache.conv_states = conv_cache['conv_states']
xBC = silu(xBC) xBC = silu(xBC)
# Split convolved xBC into x, B, C with correct dimensions
x = xBC[:, :, :d_model] x = xBC[:, :, :d_model]
B = xBC[:, :, d_model:d_model + d_state] B = xBC[:, :, d_model:d_model + d_state]
C = xBC[:, :, -d_state:] C = xBC[:, :, -d_state:]
# Reshape tensors for SSM computation x = mx.reshape(x, (batch_size, 1, n_heads, self.args.head_dim))
x = mx.reshape(x, (batch_size, 1, n_heads, d_head)) x = mx.squeeze(x, axis=1)
x = mx.squeeze(x, axis=1) # (batch, heads, dim)
B = mx.reshape(B, (batch_size, 1, d_state)) B = mx.reshape(B, (batch_size, 1, d_state))
B = mx.broadcast_to(B, (batch_size, n_heads, d_state)) B = mx.broadcast_to(B, (batch_size, n_heads, d_state))
B = mx.expand_dims(B, axis=2) # (batch, heads, 1, state) B = mx.expand_dims(B, axis=2)
C = mx.reshape(C, (batch_size, 1, d_state)) C = mx.reshape(C, (batch_size, 1, d_state))
C = mx.broadcast_to(C, (batch_size, n_heads, d_state)) C = mx.broadcast_to(C, (batch_size, n_heads, d_state))
C = mx.expand_dims(C, axis=3) # (batch, heads, state, 1) C = mx.expand_dims(C, axis=3)
# Compute SSM updates
A = -mx.exp(self.A_log) A = -mx.exp(self.A_log)
dA = mx.exp(dt * mx.expand_dims(A, 0)) dA = mx.exp(dt * mx.expand_dims(A, 0))
dA = mx.expand_dims(mx.expand_dims(dA, -1), -1) # (batch, heads, 1, 1) dA = mx.expand_dims(mx.expand_dims(dA, -1), -1)
# Update state with proper shapes x = mx.expand_dims(x, axis=3)
x = mx.expand_dims(x, axis=3) # (batch, heads, dim, 1) dBx = mx.matmul(x, B)
dBx = mx.matmul(x, B) # (batch, heads, dim, state)
ssm_state = cache[1] cache.ssm_state = cache.ssm_state * dA + dBx
ssm_state = ssm_state * dA + dBx
cache[1] = ssm_state
# Compute output y = mx.matmul(cache.ssm_state, C)
y = mx.matmul(ssm_state, C) # (batch, heads, dim, 1) y = mx.squeeze(y, axis=-1)
y = mx.squeeze(y, axis=-1) # (batch, heads, dim)
# Add skip connection
y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1) y = y + x[:, :, :, 0] * mx.expand_dims(self.D, -1)
# Reshape and process output y = mx.reshape(y, (batch_size, 1, n_heads * self.args.head_dim))
y = mx.reshape(y, (batch_size, 1, n_heads * d_head))
y = self.norm(y + z) y = self.norm(y + z)
y = self.out_proj(y) y = self.out_proj(y)
@ -440,7 +367,6 @@ class Model(nn.Module):
else: else:
logits = self.lm_head(x) logits = self.lm_head(x)
print('ouput')
return logits return logits
def make_cache(self, batch_size=1): def make_cache(self, batch_size=1):