This commit is contained in:
Goekdeniz-Guelmez 2024-10-22 21:23:47 +02:00
parent 758597eaa8
commit 55485b98e8
3 changed files with 63 additions and 21 deletions

View File

@ -338,3 +338,24 @@ class MambaCache(_BaseCache):
# Setter half of the `state` property: replaces whatever is stored in `self.cache` with `v`.
# NOTE(review): the matching getter and the expected layout of `v` are outside this excerpt — confirm against the full class.
@state.setter
def state(self, v):
self.cache = v
class Mamba2Cache(_BaseCache):
    """Two-slot cache used during Mamba2 inference.

    Slot 0 holds the convolution cache and slot 1 holds the SSM state. Both
    default to ``None`` and are read or written with ``cache[0]`` and
    ``cache[1]``.
    """

    conv_cache: Optional[mx.array] = None
    ssm_state: Optional[mx.array] = None

    def __getitem__(self, idx: int) -> Optional[mx.array]:
        """Return the entry stored in slot ``idx``.

        Args:
            idx: 0 for the convolution cache, 1 for the SSM state.

        Returns:
            The stored value, or ``None`` if the slot has not been set.

        Raises:
            IndexError: If ``idx`` is neither 0 nor 1.
        """
        if idx == 0:
            entry = self.conv_cache
        elif idx == 1:
            entry = self.ssm_state
        else:
            raise IndexError("Cache index must be 0 or 1")
        return entry

    def __setitem__(self, idx: int, value: Optional[mx.array]):
        """Store ``value`` in slot ``idx``.

        Args:
            idx: 0 for the convolution cache, 1 for the SSM state.
            value: New contents of the slot; ``None`` clears it.

        Raises:
            IndexError: If ``idx`` is neither 0 nor 1.
        """
        if idx == 0:
            self.conv_cache = value
            return
        if idx == 1:
            self.ssm_state = value
            return
        raise IndexError("Cache index must be 0 or 1")

View File

@ -193,6 +193,7 @@ class Mamba2(nn.Module):
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
self.norm = RMSNorm(args.d_inner, device=device)
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)

View File

@ -188,43 +188,52 @@ class Mamba2Block(nn.Module):
if cache is not None and self.args.use_cache:
return self.step(u, cache)
# Calculate sizes
d_model = self.args.intermediate_size
d_state = self.args.state_size
n_heads = self.args.num_heads
# Compute A
A = -mx.exp(self.A_log)
# Project input
zxbcdt = self.in_proj(u)
# Correct splits for z, xBC, dt
splits = [
self.args.intermediate_size,
self.args.intermediate_size + 2 * self.args.state_size,
self.args.num_heads,
d_model, # z
d_model + 2 * d_state, # xBC (delta, B, C concatenated)
n_heads # dt
]
z, xBC, dt = mx.split(zxbcdt, splits, axis=-1)
# Split using cumulative indices
z = zxbcdt[:, :, :splits[0]]
xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
dt = zxbcdt[:, :, -splits[2]:]
# Process dt
dt = mx.clip(
nn.softplus(dt + self.dt_bias),
self.args.time_step_min,
self.args.time_step_max
)
dt = mx.maximum(dt, self.args.time_step_floor)
# Process convolution
xBC = silu(self.conv1d(xBC))
xBC_parts = mx.split(
xBC,
[self.args.intermediate_size, self.args.state_size, self.args.state_size],
axis=-1
)
# Split convolved xBC into x, B, C
x = xBC[:, :, :d_model]
B = xBC[:, :, d_model:d_model + d_state]
C = xBC[:, :, -d_state:]
x = xBC_parts[0]
B = xBC_parts[1]
C = xBC_parts[2]
# Replace rearrange with reshape and transpose
# Reshape for SSM computation
b, l, hp = x.shape
h = self.args.num_heads
p = hp // h
x = mx.reshape(x, (b, l, h, p))
# Compute SSM
y, ssm_state = ssd(
x * mx.expand_dims(dt, -1),
A * dt,
@ -232,23 +241,34 @@ class Mamba2Block(nn.Module):
C,
self.args.chunk_size
)
# Add skip connection
y = y + x * mx.expand_dims(self.D, -1)
# Replace rearrange with reshape
y = mx.reshape(y, (b, l, h * p))
# Reshape back
y = mx.reshape(y, (b, l, h * p))
# Apply norm and projection
y = self.norm(y + z)
y = self.out_proj(y)
# Update cache if needed
if cache is not None and self.args.use_cache:
cache[1] = ssm_state
# Cast if needed
if self.args.residual_in_fp32:
y = mx.cast(y, mx.float32)
y.astype(mx.float32)
return y
def step(self, u: mx.array, cache: MambaCache):
"""
Process single or multiple tokens while maintaining state.
Args:
u: Input tensor of shape (batch_size, seq_len, hidden_size)
cache: MambaCache object containing conv cache and ssm state
"""
batch_size = u.shape[0]
seq_len = u.shape[1]
outputs = []