Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
This commit is contained in:
Awni Hannun
2024-03-23 07:13:51 -07:00
committed by GitHub
parent fbed720d6f
commit b8a348c1b8
44 changed files with 144 additions and 1155 deletions

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -114,8 +113,10 @@ class TransformerBlock(nn.Module):
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
@@ -142,7 +143,7 @@ class Qwen2Model(nn.Module):
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,