From 230abad4260e0ff4e1a3dd96f566e6675b2bdb06 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Tue, 4 Mar 2025 22:44:24 +0100
Subject: [PATCH] more clean ups

---
 llms/mlx_lm/models/olmoe.py | 14 +++++---------
 1 file changed, 5 insertions(+), 9 deletions(-)

diff --git a/llms/mlx_lm/models/olmoe.py b/llms/mlx_lm/models/olmoe.py
index d5d9b425..1297ce83 100644
--- a/llms/mlx_lm/models/olmoe.py
+++ b/llms/mlx_lm/models/olmoe.py
@@ -48,15 +48,11 @@ class Attention(nn.Module):
         self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
         self.scale = head_dim**-0.5
 
-        if hasattr(args, "attention_bias"):
-            attention_bias = args.attention_bias
-        else:
-            attention_bias = False
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
 
         self.rope = initialize_rope(
             self.head_dim,
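
Note: with the hasattr() fallback gone, Attention reads args.attention_bias unconditionally, so ModelArgs has to expose that field with a default. Below is a minimal sketch of what that implies, assuming a dataclass-style ModelArgs with a False default; the field names other than attention_bias are illustrative and the repo's actual definition may differ.

    # Hypothetical sketch, not the repo's exact ModelArgs: the point is that
    # attention_bias must exist as a field with a default, so configs that omit
    # the key still construct Attention without an AttributeError.
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ModelArgs:
        hidden_size: int = 2048
        num_attention_heads: int = 16
        num_key_value_heads: int = 16
        head_dim: Optional[int] = None
        attention_bias: bool = False  # assumed default; stands in for the removed hasattr() check

    args = ModelArgs()
    print(args.attention_bias)  # False even when a config leaves the key out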