mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-20 18:26:39 +08:00
nits + even faster
This commit is contained in:
parent
a7e35ee748
commit
c43fc7ce59
@ -127,60 +127,57 @@ class MambaBlock(nn.Module):
|
|||||||
D = self.D
|
D = self.D
|
||||||
deltaBC = self.x_proj(x)
|
deltaBC = self.x_proj(x)
|
||||||
delta, B, C = map(
|
delta, B, C = map(
|
||||||
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
self.mixer_norm if self.use_bcdt_rms else lambda x: x,
|
||||||
mx.split(
|
mx.split(
|
||||||
deltaBC,
|
deltaBC,
|
||||||
[
|
[self.time_step_rank, self.time_step_rank + self.ssm_state_size],
|
||||||
self.time_step_rank,
|
axis=-1,
|
||||||
self.time_step_rank + self.ssm_state_size
|
),
|
||||||
],
|
|
||||||
axis=-1
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if self.use_bcdt_rms:
|
if self.use_bcdt_rms:
|
||||||
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
delta, B, C = map(self.mixer_norm, (delta, B, C))
|
||||||
delta = nn.softplus(self.dt_proj(delta))
|
delta = nn.softplus(self.dt_proj(delta))
|
||||||
new_state = mx.einsum('bs,bs,sd->bsd', delta, x, B)
|
new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
|
||||||
if state is not None:
|
if state is not None:
|
||||||
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
|
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
|
||||||
y = mx.einsum('bsd,sd->bs', new_state, C)
|
y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
|
||||||
y = y + D * x
|
y = y + D * x
|
||||||
return y, new_state
|
return y, new_state
|
||||||
|
|
||||||
def _process_sequence(self, x, conv_cache, state_cache):
|
def _process_sequence(self, x, conv_cache, state_cache):
|
||||||
B, T, D = x.shape
|
B, T, D = x.shape
|
||||||
xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1)
|
xz = self.in_proj(x)
|
||||||
x_t, z_t = xz.split(indices_or_sections=2, axis=-1)
|
x, z = xz.split(indices_or_sections=2, axis=-1)
|
||||||
|
|
||||||
conv_out, new_conv_cache = self.conv1d(x_t, conv_cache)
|
conv_out, new_conv_cache = self.conv1d(x, conv_cache)
|
||||||
x_t = nn.silu(conv_out)
|
x = nn.silu(conv_out)
|
||||||
|
|
||||||
A = -mx.exp(self.A_log)
|
A = -mx.exp(self.A_log)
|
||||||
|
|
||||||
outputs = []
|
outputs = []
|
||||||
current_state = state_cache
|
current_state = state_cache
|
||||||
|
y = []
|
||||||
for t in range(T):
|
for t in range(T):
|
||||||
y_t, current_state = self.ssm_step(x_t[:, t], A, current_state)
|
y_t, current_state = self.ssm_step(x[:, t], A, current_state)
|
||||||
z_curr = nn.silu(z_t[:, t])
|
y.append(y_t)
|
||||||
output_t = self.out_proj(y_t * z_curr)
|
y = mx.stack(y, axis=1)
|
||||||
outputs.append(output_t)
|
z = self.out_proj(nn.silu(z) * y)
|
||||||
|
return z, (new_conv_cache, current_state)
|
||||||
return mx.stack(outputs, axis=1), (new_conv_cache, current_state)
|
|
||||||
|
|
||||||
def __call__(self, x, cache):
|
def __call__(self, x, cache):
|
||||||
if cache is None or isinstance(cache, list):
|
if cache is None:
|
||||||
conv_cache, state_cache = cache if cache is not None else (None, None)
|
conv_cache, state_cache = None, None
|
||||||
else:
|
else:
|
||||||
conv_cache, state_cache = cache.state
|
conv_cache, state_cache = cache[0], cache[1]
|
||||||
|
|
||||||
output, (new_conv_cache, new_state_cache) = self._process_sequence(
|
output, (new_conv_cache, new_state_cache) = self._process_sequence(
|
||||||
x, conv_cache, state_cache
|
x, conv_cache, state_cache
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(cache, MambaCache):
|
if isinstance(cache, MambaCache):
|
||||||
cache[0] = new_conv_cache
|
cache[0] = new_conv_cache
|
||||||
cache[1] = new_state_cache
|
cache[1] = new_state_cache
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user