diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index daa2954c..e123be74 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -98,10 +98,7 @@ class Mamba2Block(nn.Module): self.d_state = args.state_size self.d_conv = args.conv_kernel self.expand = args.expand - if args.intermediate_size == None: - self.d_inner = int(self.expand * self.d_model) - else: - self.d_inner = args.intermediate_size + self.d_inner = args.intermediate_size or int(self.expand * self.d_model) self.n_groups = args.n_groups self.n_heads = args.num_heads self.d_head = self.d_inner // self.n_heads