Goekdeniz-Guelmez 2024-10-22 21:23:47 +02:00
parent 758597eaa8
commit 55485b98e8
3 changed files with 63 additions and 21 deletions


@@ -338,3 +338,24 @@ class MambaCache(_BaseCache):
    @state.setter
    def state(self, v):
        self.cache = v

class Mamba2Cache(_BaseCache):
    """Cache for Mamba model inference containing conv cache and SSM state."""

    conv_cache: Optional[mx.array] = None
    ssm_state: Optional[mx.array] = None

    def __getitem__(self, idx: int) -> Optional[mx.array]:
        if idx == 0:
            return self.conv_cache
        elif idx == 1:
            return self.ssm_state
        raise IndexError("Cache index must be 0 or 1")

    def __setitem__(self, idx: int, value: Optional[mx.array]):
        if idx == 0:
            self.conv_cache = value
        elif idx == 1:
            self.ssm_state = value
        else:
            raise IndexError("Cache index must be 0 or 1")
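
A minimal usage sketch of the two-slot indexing above (hedged: the import path and the no-argument constructor are assumptions, and the shapes are purely illustrative):

import mlx.core as mx
from mlx_lm.models.cache import Mamba2Cache  # assumed location of the class above

cache = Mamba2Cache()                  # assumes _BaseCache needs no constructor args
cache[0] = mx.zeros((1, 64, 3))        # illustrative conv cache
cache[1] = mx.zeros((1, 4, 16, 8))     # illustrative SSM state
assert cache[0] is cache.conv_cache    # index 0 maps to conv_cache
assert cache[1] is cache.ssm_state     # index 1 maps to ssm_state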


@@ -193,6 +193,7 @@ class Mamba2(nn.Module):
        self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
        self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
        self.D = nn.Parameter(torch.empty(args.nheads, device=device))
        self.norm = RMSNorm(args.d_inner, device=device)
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
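
For orientation, a hedged MLX sketch of the same per-head parameters added above (sizes are illustrative; the real values come from args, and initialization here is placeholder only):

import mlx.core as mx
import mlx.nn as nn

nheads, d_inner, d_model = 4, 128, 64               # illustrative sizes
dt_bias = mx.zeros((nheads,))                       # per-head dt bias
A_log = mx.zeros((nheads,))                         # parameterizes A via A = -exp(A_log)
D = mx.ones((nheads,))                              # per-head skip coefficient
norm = nn.RMSNorm(d_inner)                          # pre-output RMSNorm
out_proj = nn.Linear(d_inner, d_model, bias=False)  # back to model width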


@@ -188,43 +188,52 @@ class Mamba2Block(nn.Module):
        if cache is not None and self.args.use_cache:
            return self.step(u, cache)

        # Calculate sizes
        d_model = self.args.intermediate_size
        d_state = self.args.state_size
        n_heads = self.args.num_heads

        # Compute A
        A = -mx.exp(self.A_log)

        # Project input
        zxbcdt = self.in_proj(u)

        # Correct splits for z, xBC, dt
        splits = [
            d_model,  # z
            d_model + 2 * d_state,  # xBC (x, B, C concatenated)
            n_heads  # dt
        ]

        # Split using cumulative indices
        z = zxbcdt[:, :, :splits[0]]
        xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
        dt = zxbcdt[:, :, -splits[2]:]
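        # Worked example (illustrative sizes, not from the config): with
        # d_model=4, d_state=2, n_heads=3 the projection width is
        # 4 + (4 + 2*2) + 3 = 15, so z = zxbcdt[..., :4],
        # xBC = zxbcdt[..., 4:12], and dt = zxbcdt[..., -3:].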
        # Process dt
        dt = mx.clip(
            nn.softplus(dt + self.dt_bias),
            self.args.time_step_min,
            self.args.time_step_max
        )
        dt = mx.maximum(dt, self.args.time_step_floor)
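        # Note: softplus keeps the time step positive; the clip and the
        # maximum with time_step_floor bound dt to a safe discretization range.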
        # Process convolution
        xBC = silu(self.conv1d(xBC))

        # Split convolved xBC into x, B, C
        x = xBC[:, :, :d_model]
        B = xBC[:, :, d_model:d_model + d_state]
        C = xBC[:, :, -d_state:]

        # Reshape for SSM computation
        b, l, hp = x.shape
        h = self.args.num_heads
        p = hp // h
        x = mx.reshape(x, (b, l, h, p))
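        # Shape note: with h = num_heads and p = head dimension, x goes from
        # (b, l, h * p) to (b, l, h, p) so the SSM update runs per head.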
        # Compute SSM
        y, ssm_state = ssd(
            x * mx.expand_dims(dt, -1),
            A * dt,
@@ -233,22 +242,33 @@ class Mamba2Block(nn.Module):
            self.args.chunk_size
        )

        # Add skip connection
        y = y + x * mx.expand_dims(self.D, -1)

        # Reshape back
        y = mx.reshape(y, (b, l, h * p))

        # Apply norm and projection
        y = self.norm(y + z)
        y = self.out_proj(y)

        # Update cache if needed
        if cache is not None and self.args.use_cache:
            cache[1] = ssm_state

        # Cast if needed
        if self.args.residual_in_fp32:
            y = y.astype(mx.float32)
        return y
    def step(self, u: mx.array, cache: MambaCache):
"""
Process single or multiple tokens while maintaining state.
Args:
u: Input tensor of shape (batch_size, seq_len, hidden_size)
cache: MambaCache object containing conv cache and ssm state
"""
        batch_size = u.shape[0]
        seq_len = u.shape[1]
        outputs = []