mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-21 12:06:51 +08:00
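Pre-compute A = -exp(A_log) once in MambaBlock.__call__ and pass it to ssm_step, instead of recomputing it inside every ssm_step call. After: 83.890 tokens-per-sec, before: 85.848 tokens-per-sec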
This commit is contained in:
parent
9494a275ac
commit
db582e4f9e
@ -123,14 +123,20 @@ class MambaBlock(nn.Module):
|
|||||||
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
self.intermediate_size, self.hidden_size, bias=args.use_bias
|
||||||
)
|
)
|
||||||
|
|
||||||
def ssm_step(self, x, state=None):
|
def ssm_step(self, x, A, state=None):
|
||||||
A = -mx.exp(self.A_log)
|
|
||||||
D = self.D
|
D = self.D
|
||||||
deltaBC = self.x_proj(x)
|
deltaBC = self.x_proj(x)
|
||||||
delta, B, C = map(self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
delta, B, C = map(
|
||||||
mx.split(deltaBC, [self.time_step_rank,
|
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
||||||
self.time_step_rank + self.ssm_state_size],
|
mx.split(
|
||||||
axis=-1))
|
deltaBC,
|
||||||
|
[
|
||||||
|
self.time_step_rank,
|
||||||
|
self.time_step_rank + self.ssm_state_size
|
||||||
|
],
|
||||||
|
axis=-1
|
||||||
|
)
|
||||||
|
)
|
||||||
if self.use_bcdt_rms:
|
if self.use_bcdt_rms:
|
||||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||||
delta = nn.softplus(self.dt_proj(delta))
|
delta = nn.softplus(self.dt_proj(delta))
|
||||||
@ -143,6 +149,9 @@ class MambaBlock(nn.Module):
|
|||||||
|
|
||||||
def __call__(self, x, cache):
|
def __call__(self, x, cache):
|
||||||
B, T, D = x.shape
|
B, T, D = x.shape
|
||||||
|
|
||||||
|
A = -mx.exp(self.A_log)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None, None]
|
cache = [None, None]
|
||||||
|
|
||||||
@ -154,7 +163,7 @@ class MambaBlock(nn.Module):
|
|||||||
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
|
conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
|
||||||
x_t = conv_out.squeeze(1)
|
x_t = conv_out.squeeze(1)
|
||||||
x_t = nn.silu(x_t)
|
x_t = nn.silu(x_t)
|
||||||
y_t, cache[1] = self.ssm_step(x_t, cache[1])
|
y_t, cache[1] = self.ssm_step(x_t, A, cache[1])
|
||||||
z_t = nn.silu(z_t)
|
z_t = nn.silu(z_t)
|
||||||
output_t = y_t * z_t
|
output_t = y_t * z_t
|
||||||
output_t = self.out_proj(output_t)
|
output_t = self.out_proj(output_t)
|
||||||
|
Loading…
Reference in New Issue
Block a user