mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
update
This commit is contained in:
@@ -188,43 +188,52 @@ class Mamba2Block(nn.Module):
|
||||
if cache is not None and self.args.use_cache:
|
||||
return self.step(u, cache)
|
||||
|
||||
# Calculate sizes
|
||||
d_model = self.args.intermediate_size
|
||||
d_state = self.args.state_size
|
||||
n_heads = self.args.num_heads
|
||||
|
||||
# Compute A
|
||||
A = -mx.exp(self.A_log)
|
||||
|
||||
# Project input
|
||||
zxbcdt = self.in_proj(u)
|
||||
|
||||
# Correct splits for z, xBC, dt
|
||||
splits = [
|
||||
self.args.intermediate_size,
|
||||
self.args.intermediate_size + 2 * self.args.state_size,
|
||||
self.args.num_heads,
|
||||
d_model, # z
|
||||
d_model + 2 * d_state, # xBC (delta, B, C concatenated)
|
||||
n_heads # dt
|
||||
]
|
||||
|
||||
z, xBC, dt = mx.split(zxbcdt, splits, axis=-1)
|
||||
|
||||
# Split using cumulative indices
|
||||
z = zxbcdt[:, :, :splits[0]]
|
||||
xBC = zxbcdt[:, :, splits[0]:splits[0] + splits[1]]
|
||||
dt = zxbcdt[:, :, -splits[2]:]
|
||||
|
||||
# Process dt
|
||||
dt = mx.clip(
|
||||
nn.softplus(dt + self.dt_bias),
|
||||
self.args.time_step_min,
|
||||
self.args.time_step_max
|
||||
)
|
||||
|
||||
dt = mx.maximum(dt, self.args.time_step_floor)
|
||||
|
||||
# Process convolution
|
||||
xBC = silu(self.conv1d(xBC))
|
||||
|
||||
xBC_parts = mx.split(
|
||||
xBC,
|
||||
[self.args.intermediate_size, self.args.state_size, self.args.state_size],
|
||||
axis=-1
|
||||
)
|
||||
# Split convolved xBC into x, B, C
|
||||
x = xBC[:, :, :d_model]
|
||||
B = xBC[:, :, d_model:d_model + d_state]
|
||||
C = xBC[:, :, -d_state:]
|
||||
|
||||
x = xBC_parts[0]
|
||||
B = xBC_parts[1]
|
||||
C = xBC_parts[2]
|
||||
|
||||
# Replace rearrange with reshape and transpose
|
||||
# Reshape for SSM computation
|
||||
b, l, hp = x.shape
|
||||
h = self.args.num_heads
|
||||
p = hp // h
|
||||
x = mx.reshape(x, (b, l, h, p))
|
||||
|
||||
|
||||
# Compute SSM
|
||||
y, ssm_state = ssd(
|
||||
x * mx.expand_dims(dt, -1),
|
||||
A * dt,
|
||||
@@ -232,23 +241,34 @@ class Mamba2Block(nn.Module):
|
||||
C,
|
||||
self.args.chunk_size
|
||||
)
|
||||
|
||||
|
||||
# Add skip connection
|
||||
y = y + x * mx.expand_dims(self.D, -1)
|
||||
# Replace rearrange with reshape
|
||||
y = mx.reshape(y, (b, l, h * p))
|
||||
|
||||
# Reshape back
|
||||
y = mx.reshape(y, (b, l, h * p))
|
||||
|
||||
# Apply norm and projection
|
||||
y = self.norm(y + z)
|
||||
y = self.out_proj(y)
|
||||
|
||||
# Update cache if needed
|
||||
if cache is not None and self.args.use_cache:
|
||||
cache[1] = ssm_state
|
||||
|
||||
# Cast if needed
|
||||
if self.args.residual_in_fp32:
|
||||
y = mx.cast(y, mx.float32)
|
||||
y.astype(mx.float32)
|
||||
|
||||
return y
|
||||
|
||||
def step(self, u: mx.array, cache: MambaCache):
|
||||
"""
|
||||
Process single or multiple tokens while maintaining state.
|
||||
Args:
|
||||
u: Input tensor of shape (batch_size, seq_len, hidden_size)
|
||||
cache: MambaCache object containing conv cache and ssm state
|
||||
"""
|
||||
batch_size = u.shape[0]
|
||||
seq_len = u.shape[1]
|
||||
outputs = []
|
||||
|
||||
Reference in New Issue
Block a user