Update MambaBlock, Batched Input Processing, Improved Cache Handling, Pre-computed Constants, Cleaner State Management, Explicit Return Values:. Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec.

This commit is contained in:
Goekdeniz-Guelmez 2025-01-20 18:59:16 +01:00
parent db582e4f9e
commit dfd51f16d6

View File

@ -147,28 +147,48 @@ class MambaBlock(nn.Module):
y = y + D * x
return y, new_state
def __call__(self, x, cache):
def _process_sequence(self, x, conv_cache, state_cache):
"""Process a sequence of inputs with cached states"""
B, T, D = x.shape
# Project all tokens at once
xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
x_t, z_t = xz.split(indices_or_sections=2, axis=-1) # Fixed: using split instead of chunk
# Handle convolution with cache
conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
x_t = nn.silu(conv_out)
# Pre-compute A matrix
A = -mx.exp(self.A_log)
if cache is None:
cache = [None, None]
# Process sequence with state
outputs = []
current_state = state_cache
for t in range(T):
xt = x[:, t, :]
xz = self.in_proj(xt)
x_t, z_t = xz.split(indices_or_sections=2, axis=1)
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, A, cache[1])
z_t = nn.silu(z_t)
output_t = y_t * z_t
output_t = self.out_proj(output_t)
y_t, current_state = self.ssm_step(x_t[:, t], A, current_state)
z_curr = nn.silu(z_t[:, t])
output_t = self.out_proj(y_t * z_curr)
outputs.append(output_t)
output = mx.stack(outputs, axis=1)
return mx.stack(outputs, axis=1), (new_conv_cache, current_state)
def __call__(self, x, cache):
if cache is None or isinstance(cache, list):
# Handle legacy cache format
conv_cache, state_cache = cache if cache is not None else (None, None)
else:
# Handle MambaCache object
conv_cache, state_cache = cache.state
output, (new_conv_cache, new_state_cache) = self._process_sequence(
x, conv_cache, state_cache
)
# Update cache
if isinstance(cache, MambaCache):
cache[0] = new_conv_cache
cache[1] = new_state_cache
return output