Fused Operations in delta, B, C = ... :. Before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec

This commit is contained in:
Goekdeniz-Guelmez 2025-01-20 18:39:22 +01:00
parent e43ac7c90e
commit 9494a275ac

View File

@ -127,14 +127,10 @@ class MambaBlock(nn.Module):
A = -mx.exp(self.A_log) A = -mx.exp(self.A_log)
D = self.D D = self.D
deltaBC = self.x_proj(x) deltaBC = self.x_proj(x)
delta, B, C = mx.split( delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x,
deltaBC, mx.split(deltaBC, [self.time_step_rank,
indices_or_sections=[ self.time_step_rank + self.ssm_state_size],
self.time_step_rank, axis=-1))
self.time_step_rank + self.ssm_state_size,
],
axis=-1,
)
if self.use_bcdt_rms: if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C)) delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta)) delta = nn.softplus(self.dt_proj(delta))