mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
update
This commit is contained in:
parent
758597eaa8
commit
55485b98e8
@ -338,3 +338,24 @@ class MambaCache(_BaseCache):
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
self.cache = v
|
||||
|
||||
|
||||
class Mamba2Cache(_BaseCache):
|
||||
"""Cache for Mamba model inference containing conv cache and SSM state."""
|
||||
conv_cache: Optional[mx.array] = None
|
||||
ssm_state: Optional[mx.array] = None
|
||||
|
||||
def __getitem__(self, idx: int) -> Optional[mx.array]:
|
||||
if idx == 0:
|
||||
return self.conv_cache
|
||||
elif idx == 1:
|
||||
return self.ssm_state
|
||||
raise IndexError("Cache index must be 0 or 1")
|
||||
|
||||
def __setitem__(self, idx: int, value: Optional[mx.array]):
|
||||
if idx == 0:
|
||||
self.conv_cache = value
|
||||
elif idx == 1:
|
||||
self.ssm_state = value
|
||||
else:
|
||||
raise IndexError("Cache index must be 0 or 1")
|
@ -193,6 +193,7 @@ class Mamba2(nn.Module):
|
||||
self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
|
||||
self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
|
||||
self.D = nn.Parameter(torch.empty(args.nheads, device=device))
|
||||
|
||||
self.norm = RMSNorm(args.d_inner, device=device)
|
||||
self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
|
||||
|
||||
|
@ -188,43 +188,52 @@ class Mamba2Block(nn.Module):
|
||||
if cache is not None and self.args.use_cache:
|
||||
return self.step(u, cache)
|
||||
|
||||
# Calculate sizes
|
||||
d_model = self.args.intermediate_size
|
||||
d_state = self.args.state_size
|
||||
n_heads = self.args.num_heads
|
||||
|
||||
# Compute A
|
||||
A = -mx.exp(self.A_log)
|
||||
|
||||
# Project input
|
||||
zxbcdt = self.in_proj(u)
|
||||
|
||||
# Correct splits for z, xBC, dt
|
||||
splits = [
|
||||
self.args.intermediate_size,
|
||||
self.args.intermediate_size + 2 * self.args.state_size,
|
||||
self.args.num_heads,
|
||||
d_model, # z
|
||||
d_model + 2 * d_state, # xBC (delta, B, C concatenated)
|
||||
n_heads # dt
|
||||
]
|
||||
|
||||
z, xBC, dt = mx.split(zxbcdt, splits, axis=-1)
|
||||
|
||||
# Split using cumulative indices
|
||||
z = zxbcdt[:, :, :splits[0]]
|
||||
xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
|
||||
dt = zxbcdt[:, :, -splits[2]:]
|
||||
|
||||
# Process dt
|
||||
dt = mx.clip(
|
||||
nn.softplus(dt + self.dt_bias),
|
||||
self.args.time_step_min,
|
||||
self.args.time_step_max
|
||||
)
|
||||
|
||||
dt = mx.maximum(dt, self.args.time_step_floor)
|
||||
|
||||
# Process convolution
|
||||
xBC = silu(self.conv1d(xBC))
|
||||
|
||||
xBC_parts = mx.split(
|
||||
xBC,
|
||||
[self.args.intermediate_size, self.args.state_size, self.args.state_size],
|
||||
axis=-1
|
||||
)
|
||||
# Split convolved xBC into x, B, C
|
||||
x = xBC[:, :, :d_model]
|
||||
B = xBC[:, :, d_model:d_model + d_state]
|
||||
C = xBC[:, :, -d_state:]
|
||||
|
||||
x = xBC_parts[0]
|
||||
B = xBC_parts[1]
|
||||
C = xBC_parts[2]
|
||||
|
||||
# Replace rearrange with reshape and transpose
|
||||
# Reshape for SSM computation
|
||||
b, l, hp = x.shape
|
||||
h = self.args.num_heads
|
||||
p = hp // h
|
||||
x = mx.reshape(x, (b, l, h, p))
|
||||
|
||||
|
||||
# Compute SSM
|
||||
y, ssm_state = ssd(
|
||||
x * mx.expand_dims(dt, -1),
|
||||
A * dt,
|
||||
@ -232,23 +241,34 @@ class Mamba2Block(nn.Module):
|
||||
C,
|
||||
self.args.chunk_size
|
||||
)
|
||||
|
||||
|
||||
# Add skip connection
|
||||
y = y + x * mx.expand_dims(self.D, -1)
|
||||
# Replace rearrange with reshape
|
||||
y = mx.reshape(y, (b, l, h * p))
|
||||
|
||||
# Reshape back
|
||||
y = mx.reshape(y, (b, l, h * p))
|
||||
|
||||
# Apply norm and projection
|
||||
y = self.norm(y + z)
|
||||
y = self.out_proj(y)
|
||||
|
||||
# Update cache if needed
|
||||
if cache is not None and self.args.use_cache:
|
||||
cache[1] = ssm_state
|
||||
|
||||
# Cast if needed
|
||||
if self.args.residual_in_fp32:
|
||||
y = mx.cast(y, mx.float32)
|
||||
y.astype(mx.float32)
|
||||
|
||||
return y
|
||||
|
||||
def step(self, u: mx.array, cache: MambaCache):
|
||||
"""
|
||||
Process single or multiple tokens while maintaining state.
|
||||
Args:
|
||||
u: Input tensor of shape (batch_size, seq_len, hidden_size)
|
||||
cache: MambaCache object containing conv cache and ssm state
|
||||
"""
|
||||
batch_size = u.shape[0]
|
||||
seq_len = u.shape[1]
|
||||
outputs = []
|
||||
|
Loading…
Reference in New Issue
Block a user