mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 18:51:18 +08:00
fix generation works too (almost)
This commit is contained in:
parent
181d6abedc
commit
cd036ccfb5
@ -157,11 +157,12 @@ class Mamba2Mixer(nn.Module):
|
|||||||
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
|
B, C = mx.split(x, indices_or_sections=[self.state_size * self.n_groups], axis=-1)
|
||||||
print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}")
|
print(f"ssm_step split shapes - B: {B.shape}, C: {C.shape}")
|
||||||
|
|
||||||
B = B.reshape(-1, self.n_groups, self.state_size)
|
batch_size = B.shape[0]
|
||||||
C = C.reshape(-1, self.n_groups, self.state_size)
|
B = B.reshape(batch_size, self.n_groups, self.state_size)
|
||||||
|
C = C.reshape(batch_size, -1, self.state_size)
|
||||||
print(f"After reshape - B: {B.shape}, C: {C.shape}")
|
print(f"After reshape - B: {B.shape}, C: {C.shape}")
|
||||||
|
|
||||||
delta = delta.reshape(-1, self.num_heads, 1)
|
delta = delta.reshape(batch_size, self.num_heads, 1)
|
||||||
A = A.reshape(1, self.num_heads, 1)
|
A = A.reshape(1, self.num_heads, 1)
|
||||||
|
|
||||||
if state is None:
|
if state is None:
|
||||||
@ -170,7 +171,7 @@ class Mamba2Mixer(nn.Module):
|
|||||||
new_state = delta * (B + state * mx.exp(delta * A))
|
new_state = delta * (B + state * mx.exp(delta * A))
|
||||||
|
|
||||||
print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}")
|
print(f"Before final computation - new_state: {new_state.shape}, C: {C.shape}")
|
||||||
y = mx.sum(new_state * C, axis=-1)
|
y = mx.sum(new_state[:, :, None, :] * C[:, None, :, :], axis=(-1, -2))
|
||||||
y = y + D * x[:, :self.num_heads]
|
y = y + D * x[:, :self.num_heads]
|
||||||
print(f"ssm_step output shape - y: {y.shape}")
|
print(f"ssm_step output shape - y: {y.shape}")
|
||||||
return y, new_state
|
return y, new_state
|
||||||
|
Loading…
Reference in New Issue
Block a user