mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
parent
2bd64b78cf
commit
c386dd5f5a
@ -21,6 +21,18 @@ class ModelArgs(BaseModelArgs):
|
||||
logit_scale: float = 0.0625
|
||||
attention_bias: bool = False
|
||||
layer_norm_bias: bool = False
|
||||
use_qk_norm: bool = False
|
||||
|
||||
|
||||
class LayerNorm2D(nn.Module):
|
||||
|
||||
def __init__(self, d1, d2, eps):
|
||||
super().__init__()
|
||||
self.weight = mx.zeros((d1, d2))
|
||||
self.eps = eps
|
||||
|
||||
def __call__(self, x):
|
||||
return self.weight * mx.fast.layer_norm(x, None, None, self.eps)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
@ -42,6 +54,13 @@ class Attention(nn.Module):
|
||||
self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
|
||||
self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)
|
||||
|
||||
self.use_qk_norm = args.use_qk_norm
|
||||
if self.use_qk_norm:
|
||||
self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps)
|
||||
self.k_norm = LayerNorm2D(
|
||||
self.n_kv_heads, head_dim, eps=args.layer_norm_eps
|
||||
)
|
||||
|
||||
self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
|
||||
|
||||
def __call__(
|
||||
@ -54,9 +73,14 @@ class Attention(nn.Module):
|
||||
|
||||
queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
|
||||
|
||||
# Prepare the queries, keys and values for the attention computation
|
||||
queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
queries = queries.reshape(B, L, self.n_heads, -1)
|
||||
keys = keys.reshape(B, L, self.n_kv_heads, -1)
|
||||
if self.use_qk_norm:
|
||||
queries = self.q_norm(queries)
|
||||
keys = self.k_norm(keys)
|
||||
|
||||
queries = queries.transpose(0, 2, 1, 3)
|
||||
keys = keys.transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
|
@ -1,3 +1,3 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
__version__ = "0.6.0"
|
||||
__version__ = "0.7.0"
|
||||
|
Loading…
Reference in New Issue
Block a user