nits + even faster

This commit is contained in:
Awni Hannun 2025-02-03 12:53:11 -08:00
parent a7e35ee748
commit c43fc7ce59

View File

@ -130,48 +130,45 @@ class MambaBlock(nn.Module):
self.mixer_norm if self.use_bcdt_rms else lambda x: x, self.mixer_norm if self.use_bcdt_rms else lambda x: x,
mx.split( mx.split(
deltaBC, deltaBC,
[ [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
self.time_step_rank, axis=-1,
self.time_step_rank + self.ssm_state_size ),
],
axis=-1
)
) )
if self.use_bcdt_rms: if self.use_bcdt_rms:
delta, B, C = map(self.mixer_norm, (delta, B, C)) delta, B, C = map(self.mixer_norm, (delta, B, C))
delta = nn.softplus(self.dt_proj(delta)) delta = nn.softplus(self.dt_proj(delta))
new_state = mx.einsum('bs,bs,sd->bsd', delta, x, B) new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(B, 1)
if state is not None: if state is not None:
new_state += state * mx.exp(mx.expand_dims(delta, -1) * A) new_state += state * mx.exp(mx.expand_dims(delta, -1) * A)
y = mx.einsum('bsd,sd->bs', new_state, C) y = (new_state @ mx.expand_dims(C, -1)).squeeze(2)
y = y + D * x y = y + D * x
return y, new_state return y, new_state
def _process_sequence(self, x, conv_cache, state_cache): def _process_sequence(self, x, conv_cache, state_cache):
B, T, D = x.shape B, T, D = x.shape
xz = self.in_proj(x.reshape(-1, D)).reshape(B, T, -1) xz = self.in_proj(x)
x_t, z_t = xz.split(indices_or_sections=2, axis=-1) x, z = xz.split(indices_or_sections=2, axis=-1)
conv_out, new_conv_cache = self.conv1d(x_t, conv_cache) conv_out, new_conv_cache = self.conv1d(x, conv_cache)
x_t = nn.silu(conv_out) x = nn.silu(conv_out)
A = -mx.exp(self.A_log) A = -mx.exp(self.A_log)
outputs = [] outputs = []
current_state = state_cache current_state = state_cache
y = []
for t in range(T): for t in range(T):
y_t, current_state = self.ssm_step(x_t[:, t], A, current_state) y_t, current_state = self.ssm_step(x[:, t], A, current_state)
z_curr = nn.silu(z_t[:, t]) y.append(y_t)
output_t = self.out_proj(y_t * z_curr) y = mx.stack(y, axis=1)
outputs.append(output_t) z = self.out_proj(nn.silu(z) * y)
return z, (new_conv_cache, current_state)
return mx.stack(outputs, axis=1), (new_conv_cache, current_state)
def __call__(self, x, cache): def __call__(self, x, cache):
if cache is None or isinstance(cache, list): if cache is None:
conv_cache, state_cache = cache if cache is not None else (None, None) conv_cache, state_cache = None, None
else: else:
conv_cache, state_cache = cache.state conv_cache, state_cache = cache[0], cache[1]
output, (new_conv_cache, new_state_cache) = self._process_sequence( output, (new_conv_cache, new_state_cache) = self._process_sequence(
x, conv_cache, state_cache x, conv_cache, state_cache