Pre-computing A_log. After: 83.890 tokens-per-sec, before: 85.848 tokens-per-sec

This commit is contained in:
Goekdeniz-Guelmez 2025-01-20 18:42:39 +01:00
parent 9494a275ac
commit db582e4f9e

View File

@ -123,14 +123,20 @@ class MambaBlock(nn.Module):
self.intermediate_size, self.hidden_size, bias=args.use_bias
)
def ssm_step(self, x, state=None):
A = -mx.exp(self.A_log)
def ssm_step(self, x, A, state=None):
D = self.D
deltaBC = self.x_proj(x)
delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x,
mx.split(deltaBC, [self.time_step_rank,
self.time_step_rank + self.ssm_state_size],
axis=-1))
delta, B, C = map(
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
mx.split(
deltaBC,
[
self.time_step_rank,
self.time_step_rank + self.ssm_state_size
],
axis=-1
)
)
if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta))
@ -143,6 +149,9 @@ class MambaBlock(nn.Module):
def __call__(self, x, cache):
B, T, D = x.shape
A = -mx.exp(self.A_log)
if cache is None:
cache = [None, None]
@ -154,7 +163,7 @@ class MambaBlock(nn.Module):
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1])
y_t, cache[1] = self.ssm_step(x_t, A, cache[1])
z_t = nn.silu(z_t)
output_t = y_t * z_t
output_t = self.out_proj(output_t)