This commit is contained in:
Goekdeniz-Guelmez 2024-12-27 15:41:54 +01:00
parent 3384d38a83
commit 4e94e87f57

View File

@ -186,8 +186,7 @@ class Mamba2Block(nn.Module):
# Update state - matches PyTorch implementation # Update state - matches PyTorch implementation
next_state = ( next_state = (
next_state * mx.expand_dims(dAt, axis=(-1, -2)) + next_state * mx.expand_dims(dAt, axis=(-1, -2)) + dBx
dBx
) )
# Compute output # Compute output