mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 10:41:18 +08:00
fix generation works too (almost)
This commit is contained in:
parent
181d6abedc
commit
cd036ccfb5
@ -157,11 +157,12 @@ class Mamba2Mixer(nn.Module):
|
||||
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
|
||||
print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}")
|
||||
|
||||
B = B.reshape(-1, self.n_groups, self.state_size)
|
||||
C = C.reshape(-1, self.n_groups, self.state_size)
|
||||
batch_size = B.shape[0]
|
||||
B = B.reshape(batch_size, self.n_groups, self.state_size)
|
||||
C = C.reshape(batch_size, -1, self.state_size)
|
||||
print(f"After reshape - B: {B.shape}, C: {C.shape}")
|
||||
|
||||
delta = delta.reshape(-1, self.num_heads, 1)
|
||||
delta = delta.reshape(batch_size, self.num_heads, 1)
|
||||
A = A.reshape(1, self.num_heads, 1)
|
||||
|
||||
if state is None:
|
||||
@ -170,7 +171,7 @@ class Mamba2Mixer(nn.Module):
|
||||
new_state = delta * (B + state * mx.exp(delta * A))
|
||||
|
||||
print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}")
|
||||
y = mx.sum(new_state * C, axis=-1)
|
||||
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
|
||||
y = y + D * x[:, :self.num_heads]
|
||||
print(f"ssm_step output shape - y: {y.shape}")
|
||||
return y, new_state
|
||||
|
Loading…
Reference in New Issue
Block a user