mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 12:06:51 +08:00
added mx.einsum() operations: before: 41.293 tokens-per-sec, after: 57.822 tokens-per-sec
This commit is contained in:
parent
07f88f8057
commit
e43ac7c90e
@ -138,10 +138,10 @@ class MambaBlock(nn.Module):
|
||||
if self.use_bcdt_rms:
|
||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||
delta = nn.softplus(self.dt_proj(delta))
|
||||
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||
new_state = mx.einsum('bs,bs,sd->bsd', delta, x, B)
|
||||
if state is not None:
|
||||
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
|
||||
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
|
||||
y = mx.einsum('bsd,sd->bs', new_state, C)
|
||||
y = y + D * x
|
||||
return y, new_state
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user