This commit is contained in:
Goekdeniz-Guelmez
2024-10-22 21:23:47 +02:00
parent 758597eaa8
commit 55485b98e8
3 changed files with 63 additions and 21 deletions

View File

@@ -188,43 +188,52 @@ class Mamba2Block(nn.Module):
if cache is not None and self.args.use_cache:
return self.step(u, cache)
# Calculate sizes
d_model = self.args.intermediate_size
d_state = self.args.state_size
n_heads = self.args.num_heads
# Compute A
A = -mx.exp(self.A_log)
# Project input
zxbcdt = self.in_proj(u)
# Correct splits for z, xBC, dt
splits = [
self.args.intermediate_size,
self.args.intermediate_size + 2 * self.args.state_size,
self.args.num_heads,
d_model, # z
d_model + 2 * d_state, # xBC (x, B, C concatenated)
n_heads # dt
]
z, xBC, dt = mx.split(zxbcdt, splits, axis=-1)
# Split using cumulative indices
z = zxbcdt[:, :, :splits[0]]
xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
dt = zxbcdt[:, :, -splits[2]:]
# Process dt
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
# Process convolution
xBC = silu(self.conv1d(xBC))
xBC_parts = mx.split(
xBC,
[self.args.intermediate_size, self.args.state_size, self.args.state_size],
axis=-1
)
# Split convolved xBC into x, B, C
x = xBC[:, :, :d_model]
B = xBC[:, :, d_model:d_model + d_state]
C = xBC[:, :, -d_state:]
x = xBC_parts[0]
B = xBC_parts[1]
C = xBC_parts[2]
# Replace rearrange with reshape and transpose
# Reshape for SSM computation
b, l, hp = x.shape
h = self.args.num_heads
p = hp // h
x = mx.reshape(x, (b, l, h, p))
# Compute SSM
y, ssm_state = ssd(
x * mx.expand_dims(dt, -1),
A * dt,
@@ -232,23 +241,34 @@ class Mamba2Block(nn.Module):
C,
self.args.chunk_size
)
# Add skip connection
y = y + x * mx.expand_dims(self.D, -1)
# Replace rearrange with reshape
y = mx.reshape(y, (b, l, h * p))
# Reshape back
y = mx.reshape(y, (b, l, h * p))
# Apply norm and projection
y = self.norm(y + z)
y = self.out_proj(y)
# Update cache if needed
if cache is not None and self.args.use_cache:
cache[1] = ssm_state
# Cast if needed
if self.args.residual_in_fp32:
y = mx.cast(y, mx.float32)
y.astype(mx.float32)
return y
def step(self, u: mx.array, cache: MambaCache):
"""
Process single or multiple tokens while maintaining state.
Args:
u: Input tensor of shape (batch_size, seq_len, hidden_size)
cache: MambaCache object containing conv cache and ssm state
"""
batch_size = u.shape[0]
seq_len = u.shape[1]
outputs = []