mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-22 13:07:55 +08:00
Fused Operations in delta, B, C = ... :. Before: 57.822 tokens-per-sec, after: 83.890 tokens-per-sec
This commit is contained in:
parent
e43ac7c90e
commit
9494a275ac
@ -127,14 +127,10 @@ class MambaBlock(nn.Module):
|
|||||||
A = -mx.exp(self.A_log)
|
A = -mx.exp(self.A_log)
|
||||||
D = self.D
|
D = self.D
|
||||||
deltaBC = self.x_proj(x)
|
deltaBC = self.x_proj(x)
|
||||||
delta, B, C = mx.split(
|
delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
||||||
deltaBC,
|
mx.split(deltaBC, [self.time_step_rank,
|
||||||
indices_or_sections=[
|
self.time_step_rank + self.ssm_state_size],
|
||||||
self.time_step_rank,
|
axis=-1))
|
||||||
self.time_step_rank + self.ssm_state_size,
|
|
||||||
],
|
|
||||||
axis=-1,
|
|
||||||
)
|
|
||||||
if self.use_bcdt_rms:
|
if self.use_bcdt_rms:
|
||||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||||
delta = nn.softplus(self.dt_proj(delta))
|
delta = nn.softplus(self.dt_proj(delta))
|
||||||
|
Loading…
Reference in New Issue
Block a user