mirror of https://github.com/ml-explore/mlx-examples.git
commit 230abad426
parent d828bc0c2d

    more clean ups
@@ -48,15 +48,11 @@ class Attention(nn.Module):
         self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads
         self.scale = head_dim**-0.5
-        if hasattr(args, "attention_bias"):
-            attention_bias = args.attention_bias
-        else:
-            attention_bias = False
-
-        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias)
-        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias)
-        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias)
+        self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
+        self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=args.attention_bias)
+        self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=args.attention_bias)
 
         self.rope = initialize_rope(
             self.head_dim,
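The cleanup works because the old `hasattr` fallback (`attention_bias = False`) moves into the config itself: when the model's args dataclass declares `attention_bias` with a default of `False`, every config can be read directly and the per-model guard becomes redundant. Below is a minimal sketch of that pattern; the `ModelArgs` field names other than `attention_bias` are illustrative assumptions, not the exact upstream definition.

```python
from dataclasses import dataclass
from typing import Optional

import mlx.nn as nn


@dataclass
class ModelArgs:
    # Hypothetical subset of the real config; only attention_bias
    # is taken from the diff above, the rest are assumptions.
    hidden_size: int = 4096
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    head_dim: Optional[int] = None
    # Defaulting to False here replaces the old
    # `if hasattr(args, "attention_bias")` guard in Attention.__init__.
    attention_bias: bool = False


args = ModelArgs()  # a checkpoint that never sets attention_bias
dim = args.hidden_size
n_heads = args.num_attention_heads
head_dim = args.head_dim or dim // n_heads

# Configs that omit the field still resolve to bias=False, so the
# projection layers can read args.attention_bias unconditionally.
q_proj = nn.Linear(dim, n_heads * head_dim, bias=args.attention_bias)
```

Pushing the default into the dataclass is the design choice that lets each model file drop its own `hasattr` boilerplate, which is presumably what "more clean ups" refers to.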