diff --git a/llms/mlx_lm/models/plamo2.py b/llms/mlx_lm/models/plamo2.py index ba4c2bb1..5935024b 100644 --- a/llms/mlx_lm/models/plamo2.py +++ b/llms/mlx_lm/models/plamo2.py @@ -310,14 +310,11 @@ class Mamba(nn.Module): axis=-1, ) - # conv x = x.reshape(bsize, -1, self.num_heads * self.hidden_size_per_head) x, conv_state = causal_conv1d_update(conv_state, x, self.conv1d.weight) BCdt = self.bcdt_proj(x) x = x.reshape(bsize, length, self.num_heads, -1) B, C, dt = mx.split(BCdt, [self.d_state, self.d_state * 2], axis=-1) - B = B[:, :, None, :] - C = C[:, :, None, :] A = -mx.exp(self.A_log.astype(mx.float32)) # (num_heads,) dt = mx.fast.rms_norm(dt, self.dt_norm_weight, self.config.rms_norm_eps) @@ -327,10 +324,6 @@ class Mamba(nn.Module): # (bsize, length, num_heads, 1) dt = self.dt_proj(dt)[..., None] - # TODO it may not be required - B = mx.broadcast_to(B, (B.shape[0], B.shape[1], self.num_heads, B.shape[3])) - C = mx.broadcast_to(C, (C.shape[0], C.shape[1], self.num_heads, C.shape[3])) - out, ssm_state = ssd_chunk_scan_combined( x, dt.reshape(bsize, length, -1),