mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 10:41:18 +08:00

update

This commit is contained in:
parent 758597eaa8
commit 55485b98e8
@@ -338,3 +338,24 @@ class MambaCache(_BaseCache):
     @state.setter
     def state(self, v):
         self.cache = v
+
+
+class Mamba2Cache(_BaseCache):
+    """Cache for Mamba model inference containing conv cache and SSM state."""
+    conv_cache: Optional[mx.array] = None
+    ssm_state: Optional[mx.array] = None
+
+    def __getitem__(self, idx: int) -> Optional[mx.array]:
+        if idx == 0:
+            return self.conv_cache
+        elif idx == 1:
+            return self.ssm_state
+        raise IndexError("Cache index must be 0 or 1")
+
+    def __setitem__(self, idx: int, value: Optional[mx.array]):
+        if idx == 0:
+            self.conv_cache = value
+        elif idx == 1:
+            self.ssm_state = value
+        else:
+            raise IndexError("Cache index must be 0 or 1")
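The new Mamba2Cache is indexable, so code that treats a per-layer cache as a two-slot container (slot 0 = conv cache, slot 1 = SSM state) keeps working. A minimal usage sketch, assuming Mamba2Cache can be constructed with no arguments; the shapes are illustrative placeholders, not the model's actual dimensions:

    import mlx.core as mx

    cache = Mamba2Cache()
    cache[0] = mx.zeros((1, 3, 64))       # conv cache, set via __setitem__
    cache[1] = mx.zeros((1, 4, 16, 16))   # SSM state
    assert cache[0] is cache.conv_cache   # __getitem__ mirrors the attributes
    # cache[2] would raise IndexError("Cache index must be 0 or 1")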
@@ -193,6 +193,7 @@ class Mamba2(nn.Module):
         self.dt_bias = nn.Parameter(torch.empty(args.nheads, device=device))
         self.A_log = nn.Parameter(torch.empty(args.nheads, device=device))
         self.D = nn.Parameter(torch.empty(args.nheads, device=device))
+
         self.norm = RMSNorm(args.d_inner, device=device)
         self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=False, device=device)
 
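For context on the parameters above: A_log stores the log of the negated state-transition coefficient, and the MLX forward pass below recovers A = -exp(A_log), which keeps A strictly negative regardless of the learned value. A worked sketch with an illustrative number:

    import math

    A_log = math.log(2.0)   # stored parameter
    A = -math.exp(A_log)    # A == -2.0; exp(.) > 0 guarantees A < 0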
@@ -188,43 +188,52 @@ class Mamba2Block(nn.Module):
         if cache is not None and self.args.use_cache:
             return self.step(u, cache)
 
+        # Calculate sizes
+        d_model = self.args.intermediate_size
+        d_state = self.args.state_size
+        n_heads = self.args.num_heads
+
+        # Compute A
         A = -mx.exp(self.A_log)
 
+        # Project input
         zxbcdt = self.in_proj(u)
 
+        # Correct splits for z, xBC, dt
         splits = [
-            self.args.intermediate_size,
-            self.args.intermediate_size + 2 * self.args.state_size,
-            self.args.num_heads,
+            d_model,  # z
+            d_model + 2 * d_state,  # xBC (x, B, C concatenated)
+            n_heads,  # dt
         ]
 
-        z, xBC, dt = mx.split(zxbcdt, splits, axis=-1)
+        # Split using cumulative indices
+        z = zxbcdt[:, :, :splits[0]]
+        xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
+        dt = zxbcdt[:, :, -splits[2]:]
 
+        # Process dt
         dt = mx.clip(
             nn.softplus(dt + self.dt_bias),
             self.args.time_step_min,
             self.args.time_step_max
         )
-
         dt = mx.maximum(dt, self.args.time_step_floor)
 
+        # Process convolution
         xBC = silu(self.conv1d(xBC))
 
-        xBC_parts = mx.split(
-            xBC,
-            [self.args.intermediate_size, self.args.state_size, self.args.state_size],
-            axis=-1
-        )
-        x = xBC_parts[0]
-        B = xBC_parts[1]
-        C = xBC_parts[2]
-
-        # Replace rearrange with reshape and transpose
+        # Split convolved xBC into x, B, C
+        x = xBC[:, :, :d_model]
+        B = xBC[:, :, d_model:d_model + d_state]
+        C = xBC[:, :, -d_state:]
+
+        # Reshape for SSM computation
         b, l, hp = x.shape
         h = self.args.num_heads
         p = hp // h
         x = mx.reshape(x, (b, l, h, p))
 
+        # Compute SSM
         y, ssm_state = ssd(
             x * mx.expand_dims(dt, -1),
             A * dt,
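The move from mx.split to explicit slicing matters if mx.split follows NumPy semantics, where a list argument gives split indices rather than section sizes; passing the size list `splits` directly would then cut the projection at the wrong offsets. A small sketch of the distinction (illustrative sizes, assuming NumPy-style mx.split):

    import mlx.core as mx

    a = mx.arange(10)

    # Sizes [3, 5, 2] correspond to cumulative split indices [3, 8].
    z, xbc, dt = mx.split(a, [3, 8])

    # The commit's slicing expresses the same sections via the sizes:
    z2, xbc2, dt2 = a[:3], a[3:3 + 5], a[-2:]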
@@ -232,23 +241,34 @@ class Mamba2Block(nn.Module):
             C,
             self.args.chunk_size
         )
 
+        # Add skip connection
         y = y + x * mx.expand_dims(self.D, -1)
-        # Replace rearrange with reshape
-        y = mx.reshape(y, (b, l, h * p))
+
+        # Reshape back
+        y = mx.reshape(y, (b, l, h * p))
 
+        # Apply norm and projection
         y = self.norm(y + z)
         y = self.out_proj(y)
 
+        # Update cache if needed
         if cache is not None and self.args.use_cache:
             cache[1] = ssm_state
 
+        # Cast if needed
         if self.args.residual_in_fp32:
-            y = mx.cast(y, mx.float32)
+            y = y.astype(mx.float32)
 
         return y
 
     def step(self, u: mx.array, cache: MambaCache):
+        """
+        Process single or multiple tokens while maintaining state.
+        Args:
+            u: Input tensor of shape (batch_size, seq_len, hidden_size)
+            cache: MambaCache object containing conv cache and SSM state
+        """
         batch_size = u.shape[0]
         seq_len = u.shape[1]
         outputs = []
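One note on the fp32 cast: MLX arrays are immutable, so astype returns a new array rather than casting in place, and the result has to be rebound to take effect. A tiny sketch:

    import mlx.core as mx

    x = mx.zeros((2, 3), dtype=mx.float16)
    x.astype(mx.float32)      # returns a new array; x itself is unchanged
    x = x.astype(mx.float32)  # rebinding keeps the cast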