Pre-computing A_log. After: 83.890 tokens-per-sec, before: 85.848 tokens-per-sec

This commit is contained in:
Goekdeniz-Guelmez 2025-01-20 18:42:39 +01:00
parent 9494a275ac
commit db582e4f9e

View File

@ -123,14 +123,20 @@ class MambaBlock(nn.Module):
self.intermediate_size, self.hidden_size, bias=args.use_bias self.intermediate_size, self.hidden_size, bias=args.use_bias
) )
def ssm_step(self, x, state=None): def ssm_step(self, x, A, state=None):
A = -mx.exp(self.A_log)
D = self.D D = self.D
deltaBC = self.x_proj(x) deltaBC = self.x_proj(x)
delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x, delta, B, C = map(
mx.split(deltaBC, [self.time_step_rank, self.mixer_norm if self.use_bcdt_rms else lambda x: x,
self.time_step_rank + self.ssm_state_size], mx.split(
axis=-1)) deltaBC,
[
self.time_step_rank,
self.time_step_rank + self.ssm_state_size
],
axis=-1
)
)
if self.use_bcdt_rms: if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C)) delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta)) delta = nn.softplus(self.dt_proj(delta))
@ -143,6 +149,9 @@ class MambaBlock(nn.Module):
def __call__(self, x, cache): def __call__(self, x, cache):
B, T, D = x.shape B, T, D = x.shape
A = -mx.exp(self.A_log)
if cache is None: if cache is None:
cache = [None, None] cache = [None, None]
@ -154,7 +163,7 @@ class MambaBlock(nn.Module):
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
x_t = conv_out.squeeze(1) x_t = conv_out.squeeze(1)
x_t = nn.silu(x_t) x_t = nn.silu(x_t)
y_t, cache[1] = self.ssm_step(x_t, cache[1]) y_t, cache[1] = self.ssm_step(x_t, A, cache[1])
z_t = nn.silu(z_t) z_t = nn.silu(z_t)
output_t = y_t * z_t output_t = y_t * z_t
output_t = self.out_proj(output_t) output_t = self.out_proj(output_t)