diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index d63a3e3e..c26e2925 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -28,7 +28,6 @@ class ModelArgs(BaseModelArgs): rms_norm: bool chunk_size: int tie_word_embeddings: bool - intermediate_size: int time_step_limit: Tuple[float, float] time_step_rank: Union[int, str] time_step_min: float @@ -106,15 +105,6 @@ class Mamba2Block(nn.Module): # Input projection d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias) - - # Improved initialization of dt - dt = mx.exp( - mx.random.uniform( - low=math.log(args.time_step_min), - high=math.log(args.time_step_max), - shape=(self.n_heads,) - ) - ) self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range