diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py
index d5d9b425..1297ce83 100644
--- a/llms/mlx_lm/models/olmoe.py
+++ b/llms/mlx_lm/models/olmoe.py
@@ -48,15 +48,11 @@ class Attention(nn.Module):
         self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
         self.scale = head_dim**-0.5
 
-        if hasattr(args, "attention_bias"):
-            attention_bias = args.attention_bias
-        else:
-            attention_bias = False
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
 
         self.rope = initialize_rope(
             self.head_dim,
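Note: dropping the `hasattr` fallback relies on `attention_bias` being declared on the model's config dataclass with a default, so `args.attention_bias` always exists. A minimal sketch of that assumption (field names and defaults are illustrative, not taken from this diff):

```python
# Sketch (assumption): if attention_bias is a dataclass field with a default,
# every config instance exposes the attribute and the hasattr() guard removed
# above becomes redundant.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:
    hidden_size: int = 2048
    num_attention_heads: int = 16
    num_key_value_heads: int = 16
    head_dim: Optional[int] = None
    attention_bias: bool = False  # default covers configs that omit the key


args = ModelArgs()
print(args.attention_bias)  # False even when no config provides the field
```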