This commit is contained in:
Awni Hannun 2025-02-24 09:07:07 -08:00
parent 9392bc70f7
commit 2edcc0355f

View File

@ -310,14 +310,11 @@ class Mamba(nn.Module):
axis=-1,
)
# conv
x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head)
x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight)
BCdt = self.bcdt_proj(x)
x = x.reshape(bsize, length, self.num_heads, -1)
B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
B = B[:, :, None, :]
C = C[:, :, None, :]
A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps)
@ -327,10 +324,6 @@ class Mamba(nn.Module):
# (bsize, length, num_heads, 1)
dt = self.dt_proj(dt)[..., None]
# TODO it may not be required
B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3]))
C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3]))
out, ssm_state = ssd_chunk_scan_combined(
x,
dt.reshape(bsize, length, -1),