Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-28 20:25:22 +08:00
codestral inference actually works now
This commit is contained in:
parent 5a6ada2df0
commit a6a92cb91f
@@ -161,7 +161,7 @@ class Mamba2Block(nn.Module):
         super().__init__()
         self.args = args

-        # Dimensions
+        # Same dimensions as before
         self.d_model = args.hidden_size
         self.d_state = args.state_size
         self.d_conv = args.conv_kernel
@@ -190,51 +190,46 @@ class Mamba2Block(nn.Module):
         )

         # Output projections
-        self.norm = MambaRMSNormGated(
-            self.d_inner,
-            eps=args.layer_norm_epsilon,
-            norm_before_gate=args.norm_before_gate
-        )
+        self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)

     def __call__(self, u: mx.array, cache=None):
         batch_size, seq_len, _ = u.shape
-        if cache is None:
-            cache = [None, None]

+        # Get or initialize states from cache
+        if cache is None:
+            cache = [None, None]  # [conv_state, ssm_state]
+        conv_state, _ = cache  # We ignore ssm_state as it's not used in the parallel version

         # Project input
         zxBCdt = self.in_proj(u)

         # Split projections
         z, xBC, dt = mx.split(
             zxBCdt,
             [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
             axis=-1
         )

         # Process convolution
-        xBC, conv_state = self.conv1d(xBC, cache[0])
+        xBC, conv_state = self.conv1d(xBC, conv_state)
         xBC = silu(xBC)
-        if cache is not None:
-            cache[0] = conv_state
         xBC = xBC[:, :seq_len, :]

-        # Split and reshape conv output
+        # Split conv output
         x, B, C = mx.split(
             xBC,
             [self.d_inner, self.d_inner + self.d_state * self.n_groups],
             axis=-1
         )

         # Reshape for SSM processing
         x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
         B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
         C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))

-        # Get parameters for attention computation
+        # Process with parallel attention
         A = -mx.exp(self.A_log)

-        # Compute parallel attention
         y, next_state = ssd_forward_attn(
             x=x,
             dt=dt,
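The split boundaries passed to mx.split above are cumulative offsets along the last axis, so the projection separates into the gate branch z (d_inner wide), the convolution input xBC (d_inner plus 2 * n_groups * d_state for B and C), and the per-head time-step logits dt. A minimal sketch of that split; the names mirror the diff, but the dimensions are made-up toy values, not the Codestral config:

import mlx.core as mx

# Toy dimensions for illustration only; the real values come from the model config.
d_inner, n_groups, d_state, n_heads = 8, 1, 4, 2
d_in_proj = 2 * d_inner + 2 * n_groups * d_state + n_heads  # width of the in_proj output

zxBCdt = mx.zeros((1, 3, d_in_proj))  # (batch, seq_len, d_in_proj)

# Splitting at d_inner and 2*d_inner + 2*n_groups*d_state yields three chunks:
#   z   -> width d_inner (8)                        gate branch
#   xBC -> width d_inner + 2*n_groups*d_state (16)  conv input carrying x, B and C
#   dt  -> width n_heads (2)                        per-head time-step logits
z, xBC, dt = mx.split(
    zxBCdt,
    [d_inner, 2 * d_inner + 2 * n_groups * d_state],
    axis=-1,
)
print(z.shape, xBC.shape, dt.shape)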
@@ -244,17 +239,24 @@ class Mamba2Block(nn.Module):
             D=self.D,
             dt_bias=self.dt_bias,
             dt_min=self.args.time_step_min,
-            dt_max=self.args.time_step_max,
+            dt_max=self.args.time_step_max
         )

-        # Update cache
-        if cache is not None:
-            cache[1] = next_state
-
-        # Apply normalization and output projection
-        y = self.norm(y, z)
+        # Apply normalization based on norm_before_gate setting
+        if self.args.norm_before_gate:
+            y = self.norm(y)
+            y = y * nn.silu(z)
+        else:
+            y = y * nn.silu(z)
+            y = self.norm(y)

+        # Final projection
         y = self.out_proj(y)

+        # Update cache
+        cache[0] = conv_state
+        cache[1] = next_state
+
         return y
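The new branch replaces the fused MambaRMSNormGated with a plain nn.RMSNorm and makes the gate ordering explicit: with norm_before_gate the output is normalized first and then multiplied by silu(z); otherwise it is gated first and the gated result is normalized. A standalone sketch of that ordering; gated_output is a hypothetical helper for illustration, not part of the diff:

import mlx.core as mx
import mlx.nn as nn

def gated_output(y: mx.array, z: mx.array, norm: nn.RMSNorm, norm_before_gate: bool) -> mx.array:
    # norm_before_gate=True:  normalize y first, then gate with silu(z)
    # norm_before_gate=False: gate first, then normalize the gated result
    if norm_before_gate:
        return norm(y) * nn.silu(z)
    return norm(y * nn.silu(z))

# Toy usage with made-up shapes
d_inner = 8
norm = nn.RMSNorm(d_inner, eps=1e-5)
y = mx.random.normal((1, 3, d_inner))
z = mx.random.normal((1, 3, d_inner))
print(gated_output(y, z, norm, norm_before_gate=True).shape)

The two orderings are not numerically equivalent, which is why the flag has to be threaded through from the config rather than folded into a single fused norm.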