Fix: generation works too (almost)

This commit is contained in:
Goekdeniz-Guelmez 2024-10-16 21:13:36 +02:00
parent 181d6abedc
commit cd036ccfb5

View File

@@ -157,11 +157,12 @@ class Mamba2Mixer(nn.Module):
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}")
B = B.reshape(-1, self.n_groups, self.state_size)
C = C.reshape(-1, self.n_groups, self.state_size)
batch_size = B.shape[0]
B = B.reshape(batch_size, self.n_groups, self.state_size)
C = C.reshape(batch_size, -1, self.state_size)
print(f"After reshape - B: {B.shape}, C: {C.shape}")
delta = delta.reshape(-1, self.num_heads, 1)
delta = delta.reshape(batch_size, self.num_heads, 1)
A = A.reshape(1, self.num_heads, 1)
if state is None:
@@ -170,7 +171,7 @@ class Mamba2Mixer(nn.Module):
new_state = delta * (B + state * mx.exp(delta * A))
print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}")
y = mx.sum(new_state * C, axis=-1)
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
y = y + D * x[:, :self.num_heads]
print(f"ssm_step output shape - y: {y.shape}")
return y, new_state