Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-26 10:41:18 +08:00
parent 2bd64b78cf
commit c386dd5f5a
@@ -21,6 +21,18 @@ class ModelArgs(BaseModelArgs):
     logit_scale: float = 0.0625
     attention_bias: bool = False
     layer_norm_bias: bool = False
+    use_qk_norm: bool = False
+
+
+class LayerNorm2D(nn.Module):
+
+    def __init__(self, d1, d2, eps):
+        super().__init__()
+        self.weight = mx.zeros((d1, d2))
+        self.eps = eps
+
+    def __call__(self, x):
+        return self.weight * mx.fast.layer_norm(x, None, None, self.eps)
+
+
 class Attention(nn.Module):
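
For context, a minimal sketch of what the new LayerNorm2D computes, under assumed shapes (batch 1, sequence 10, 8 heads, head_dim 64, all made up for illustration). Calling mx.fast.layer_norm with None for weight and bias normalizes over the last axis without an affine step; the 2-D per-head weight is then applied by broadcasting:

    import mlx.core as mx

    # Hypothetical shapes: (B, L, n_heads, head_dim).
    x = mx.random.normal((1, 10, 8, 64))
    weight = mx.ones((8, 64))  # one scale row per head

    # Normalize over the last axis, then scale per head and per channel.
    # The (8, 64) weight broadcasts across the leading (B, L) axes.
    y = weight * mx.fast.layer_norm(x, None, None, 1e-5)
    print(y.shape)  # (1, 10, 8, 64)

Initializing weight with mx.zeros in the module is presumably just a placeholder; the actual values are overwritten when the checkpoint weights are loaded.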
@@ -42,6 +54,13 @@ class Attention(nn.Module):
         self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attetion_bias)
         self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attetion_bias)
 
+        self.use_qk_norm = args.use_qk_norm
+        if self.use_qk_norm:
+            self.q_norm = LayerNorm2D(self.n_heads, head_dim, eps=args.layer_norm_eps)
+            self.k_norm = LayerNorm2D(
+                self.n_kv_heads, head_dim, eps=args.layer_norm_eps
+            )
+
         self.rope = nn.RoPE(head_dim, traditional=True, base=args.rope_theta)
 
     def __call__(
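
The query and key norms get separate weight shapes because, with grouped-query attention, the two paths have different head counts. A sketch with hypothetical numbers (32 query heads sharing 8 key/value heads, head_dim 128), reusing the LayerNorm2D from the first hunk:

    import mlx.core as mx
    import mlx.nn as nn

    class LayerNorm2D(nn.Module):  # as added in the first hunk
        def __init__(self, d1, d2, eps):
            super().__init__()
            self.weight = mx.zeros((d1, d2))
            self.eps = eps

        def __call__(self, x):
            return self.weight * mx.fast.layer_norm(x, None, None, self.eps)

    # Hypothetical GQA configuration, for illustration only.
    n_heads, n_kv_heads, head_dim = 32, 8, 128
    q_norm = LayerNorm2D(n_heads, head_dim, eps=1e-5)     # weight: (32, 128)
    k_norm = LayerNorm2D(n_kv_heads, head_dim, eps=1e-5)  # weight: (8, 128)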
@@ -54,9 +73,14 @@ class Attention(nn.Module):
 
         queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x)
 
         # Prepare the queries, keys and values for the attention computation
-        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
+        queries = queries.reshape(B, L, self.n_heads, -1)
+        keys = keys.reshape(B, L, self.n_kv_heads, -1)
+        if self.use_qk_norm:
+            queries = self.q_norm(queries)
+            keys = self.k_norm(keys)
+
+        queries = queries.transpose(0, 2, 1, 3)
+        keys = keys.transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
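
The reason the norm now runs before the transpose: the (n_heads, head_dim) weight broadcasts against the trailing axes of the (B, L, n_heads, head_dim) layout, but generally not against the transposed (B, n_heads, L, head_dim) layout. A quick sketch with made-up shapes:

    import mlx.core as mx

    B, L, n_heads, head_dim = 1, 10, 8, 64
    w = mx.ones((n_heads, head_dim))
    q = mx.random.normal((B, L, n_heads, head_dim))

    # Broadcasts fine: (8, 64) lines up with the trailing (n_heads, head_dim).
    _ = w * mx.fast.layer_norm(q, None, None, 1e-5)

    # After transposing to (B, n_heads, L, head_dim), the trailing axes are
    # (L, head_dim) = (10, 64), which no longer lines up with (8, 64).
    qt = q.transpose(0, 2, 1, 3)
    try:
        _ = w * mx.fast.layer_norm(qt, None, None, 1e-5)
    except Exception:
        print("broadcast fails after the transpose, as expected")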
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.6.0"
+__version__ = "0.7.0"