Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup
* bump mlx versions
* minor update
* use fast layer norm
* version bump
* update requirement for whisper
* update requirement for gguf
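The diff below drops the per-model RMSNorm imported from .layers in favor of nn.RMSNorm, which, as the commit title suggests, is backed by the fused mx.fast.rms_norm kernel in recent mlx releases. The following sketch is not from the repository: CustomRMSNorm is a hypothetical stand-in for the removed per-model implementation, shown only to illustrate that nn.RMSNorm is a drop-in replacement.

import mlx.core as mx
import mlx.nn as nn


class CustomRMSNorm(nn.Module):
    """Hypothetical unfused RMSNorm of the kind the diff removes."""

    def __init__(self, dims: int, eps: float = 1e-6):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        # Scale by the reciprocal root-mean-square over the feature axis.
        norm = mx.rsqrt(mx.mean(mx.square(x), axis=-1, keepdims=True) + self.eps)
        return self.weight * x * norm


x = mx.random.normal((2, 16, 512))
old_norm = CustomRMSNorm(512, eps=1e-6)   # unfused, per-model copy
new_norm = nn.RMSNorm(512, eps=1e-6)      # fused replacement used in the diff
print(mx.allclose(old_norm(x), new_norm(x), atol=1e-4))  # expected: True

Since nn.RMSNorm also initializes its weight to ones, the two modules agree up to numerical precision, so the swap needs no weight conversion.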
@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .layers import RMSNorm
 
 
 @dataclass
@@ -102,9 +101,9 @@ class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
 
-        self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
         self.attn = Attention(args)
-        self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
         self.mlp = MLP(args)
 
     def __call__(self, x, mask=None, cache=None):
@@ -124,7 +123,7 @@ class QwenModel(nn.Module):
         super().__init__()
         self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
         self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
-        self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
 
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)
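The commit message also lists "use sdpa", a change not visible in the hunks above. As a rough illustration only, assuming the usual (batch, heads, length, head_dim) layout and the standard 1/sqrt(head_dim) scale, mx.fast.scaled_dot_product_attention fuses the softmax(q k^T * scale) v computation that the models otherwise spell out by hand:

import mlx.core as mx

B, n_heads, L, head_dim = 1, 8, 16, 64
q = mx.random.normal((B, n_heads, L, head_dim))
k = mx.random.normal((B, n_heads, L, head_dim))
v = mx.random.normal((B, n_heads, L, head_dim))
scale = head_dim ** -0.5

# Fused kernel call.
fast_out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale)

# Unfused reference for comparison.
scores = (q * scale) @ k.transpose(0, 1, 3, 2)
ref_out = mx.softmax(scores, axis=-1) @ v
print(mx.allclose(fast_out, ref_out, atol=1e-3))  # expected: True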