mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 04:25:06 +08:00
cleanup
This commit is contained in:
parent
9392bc70f7
commit
2edcc0355f
@ -310,14 +310,11 @@ class Mamba(nn.Module):
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
# conv
|
||||
x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head)
|
||||
x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight)
|
||||
BCdt = self.bcdt_proj(x)
|
||||
x = x.reshape(bsize, length, self.num_heads, -1)
|
||||
B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1)
|
||||
B = B[:, :, None, :]
|
||||
C = C[:, :, None, :]
|
||||
|
||||
A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,)
|
||||
dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps)
|
||||
@ -327,10 +324,6 @@ class Mamba(nn.Module):
|
||||
# (bsize, length, num_heads, 1)
|
||||
dt = self.dt_proj(dt)[..., None]
|
||||
|
||||
# TODO it may not be required
|
||||
B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3]))
|
||||
C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3]))
|
||||
|
||||
out, ssm_state = ssd_chunk_scan_combined(
|
||||
x,
|
||||
dt.reshape(bsize, length, -1),
|
||||
|
Loading…
Reference in New Issue
Block a user