This commit is contained in:
Goekdeniz-Guelmez 2024-12-27 15:41:54 +01:00
parent 3384d38a83
commit 4e94e87f57

View File

@ -186,8 +186,7 @@ class Mamba2Block(nn.Module):
# Update state - matches PyTorch implementation
next_state = (
next_state * mx.expand_dims(dAt, axis=(-1, -2)) +
dBx
next_state * mx.expand_dims(dAt, axis=(-1, -2)) + dBx
)
# Compute output