codestral inference actually works now

Goekdeniz-Guelmez 2025-01-21 21:01:39 +01:00
parent 5a6ada2df0
commit a6a92cb91f


@@ -161,7 +161,7 @@ class Mamba2Block(nn.Module):
super().__init__()
self.args = args
# Dimensions
# Same dimensions as before
self.d_model = args.hidden_size
self.d_state = args.state_size
self.d_conv = args.conv_kernel
@@ -190,51 +190,46 @@ class Mamba2Block(nn.Module):
)
# Output projections
self.norm = MambaRMSNormGated(
self.d_inner,
eps=args.layer_norm_epsilon,
norm_before_gate=args.norm_before_gate
)
self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
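The fused MambaRMSNormGated is swapped here for a plain nn.RMSNorm, so the gating with z has to happen explicitly in __call__ (see the norm_before_gate branch further down). A minimal sketch of the two orderings this enables, using made-up shapes rather than this model's real dimensions:

    import mlx.core as mx
    import mlx.nn as nn

    norm = nn.RMSNorm(8, eps=1e-5)
    y = mx.random.normal((2, 4, 8))   # SSM output (illustrative shape)
    z = mx.random.normal((2, 4, 8))   # gate branch from in_proj

    # norm_before_gate=True: normalize first, then gate with silu(z)
    out_before = norm(y) * nn.silu(z)

    # norm_before_gate=False: gate first, then normalize the gated activations
    out_after = norm(y * nn.silu(z))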
def __call__(self, u: mx.array, cache=None):
batch_size, seq_len, _ = u.shape
if cache is None:
cache = [None, None]
# Get or initialize states from cache
if cache is None:
cache = [None, None] # [conv_state, ssm_state]
conv_state, _ = cache # We ignore ssm_state as it's not used in the parallel version
# Project input
zxBCdt = self.in_proj(u)
# Split projections
z, xBC, dt = mx.split(
zxBCdt,
[self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
axis=-1
)
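The two split points follow from how in_proj packs z, xBC and dt along the last axis. A small sketch of the arithmetic with illustrative sizes (d_inner = n_heads * d_head is assumed, as in the usual Mamba-2 layout):

    import mlx.core as mx

    # Illustrative sizes only, not this model's real config values.
    d_inner, n_groups, d_state, n_heads = 64, 2, 16, 4

    # in_proj output = [ z | xBC | dt ] along the last axis:
    #   z   : d_inner
    #   xBC : d_inner + 2 * n_groups * d_state   (x, B and C packed together)
    #   dt  : n_heads
    d_in_proj = 2 * d_inner + 2 * n_groups * d_state + n_heads

    zxBCdt = mx.zeros((1, 10, d_in_proj))
    z, xBC, dt = mx.split(
        zxBCdt,
        [d_inner, 2 * d_inner + 2 * n_groups * d_state],
        axis=-1,
    )
    assert z.shape[-1] == d_inner
    assert xBC.shape[-1] == d_inner + 2 * n_groups * d_state
    assert dt.shape[-1] == n_heads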
# Process convolution
xBC, conv_state = self.conv1d(xBC, cache[0])
xBC, conv_state = self.conv1d(xBC, conv_state)
xBC = silu(xBC)
if cache is not None:
cache[0] = conv_state
xBC = xBC[:, :seq_len, :]
# Split and reshape conv output
# Split conv output
x, B, C = mx.split(
xBC,
[self.d_inner, self.d_inner + self.d_state * self.n_groups],
axis=-1
)
# Reshape for SSM processing
x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
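After the second split, x is regrouped into heads and B/C into groups for the SSD computation; a toy illustration of the resulting shapes (sizes are made up):

    import mlx.core as mx

    batch_size, seq_len = 2, 5
    n_heads, d_head, n_groups, d_state = 4, 16, 2, 8
    d_inner = n_heads * d_head

    x = mx.zeros((batch_size, seq_len, d_inner))
    B = mx.zeros((batch_size, seq_len, n_groups * d_state))
    C = mx.zeros((batch_size, seq_len, n_groups * d_state))

    x = mx.reshape(x, (batch_size, seq_len, n_heads, d_head))  # (2, 5, 4, 16)
    B = mx.reshape(B, (batch_size, seq_len, n_groups, -1))     # (2, 5, 2, 8)
    C = mx.reshape(C, (batch_size, seq_len, n_groups, -1))     # (2, 5, 2, 8)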
# Get parameters for attention computation
# Process with parallel attention
A = -mx.exp(self.A_log)
# Compute parallel attention
y, next_state = ssd_forward_attn(
x=x,
dt=dt,
@@ -244,17 +239,24 @@ class Mamba2Block(nn.Module):
D=self.D,
dt_bias=self.dt_bias,
dt_min=self.args.time_step_min,
dt_max=self.args.time_step_max,
dt_max=self.args.time_step_max
)
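ssd_forward_attn is this port's parallel ("attention-form") SSD kernel; its internals are not shown in this diff. For intuition only, here is a naive per-step reference of the recurrence the parallel path is expected to reproduce for a single head, assuming Mamba-2-style scalar decay A per head; the function name and shapes are hypothetical, and the dt_bias/softplus clamping applied by the real kernel is omitted:

    import mlx.core as mx

    def ssd_step_reference(x, dt, A, B, C, D):
        # x: (T, d_head), dt: (T,), A: float, B: (T, d_state), C: (T, d_state), D: float
        T, d_head = x.shape
        d_state = B.shape[-1]
        h = mx.zeros((d_head, d_state))          # running SSM state for one head
        ys = []
        for t in range(T):
            decay = mx.exp(dt[t] * A)            # per-step state decay
            h = decay * h + dt[t] * mx.outer(x[t], B[t])
            ys.append(h @ C[t] + D * x[t])       # readout plus skip (D) term
        return mx.stack(ys)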
# Update cache
if cache is not None:
cache[1] = next_state
# Apply normalization and output projection
y = self.norm(y, z)
# Apply normalization based on norm_before_gate setting
if self.args.norm_before_gate:
y = self.norm(y)
y = y * nn.silu(z)
else:
y = y * nn.silu(z)
y = self.norm(y)
# Final projection
y = self.out_proj(y)
# Update cache
cache[0] = conv_state
cache[1] = next_state
return y
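With the cache now updated unconditionally at the end, incremental decoding just keeps passing the same two-slot list back in. A hypothetical usage sketch (args stands for this model's ModelArgs-style config):

    import mlx.core as mx

    block = Mamba2Block(args)          # args: hidden_size, state_size, ... (assumed)
    cache = [None, None]               # [conv_state, ssm_state]

    prompt = mx.zeros((1, 12, args.hidden_size))
    y = block(prompt, cache=cache)     # prefill populates conv_state / ssm_state

    step = mx.zeros((1, 1, args.hidden_size))
    y = block(step, cache=cache)       # single-token decode reuses the cache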