added mx.einsum() operations: before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec

This commit is contained in:
Goekdeniz-Guelmez 2025-01-20 18:37:58 +01:00
parent 07f88f8057
commit e43ac7c90e

View File

@ -138,10 +138,10 @@ class MambaBlock(nn.Module):
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
new_state = mx.einsum('bs,bs,sd->bsd', delta, x, B)
if state is not None:
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
y = mx.einsum('bsd,sd->bs', new_state, C)
y = y + D * x
return y, new_state