mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup * bump mlx versions * minor update * use fast layer norm * version bump * update requirement for whisper * update requirement for gguf
This commit is contained in:
@@ -6,7 +6,6 @@ import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .base import BaseModelArgs
|
||||
from .layers import LayerNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -120,8 +119,10 @@ class DecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.self_attn = Attention(config=config)
|
||||
self.mlp = MLP(config.hidden_size, config.intermediate_size)
|
||||
self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.post_attention_layernorm = LayerNorm(
|
||||
self.input_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
self.post_attention_layernorm = nn.LayerNorm(
|
||||
config.hidden_size, eps=config.layer_norm_eps
|
||||
)
|
||||
|
||||
@@ -138,7 +139,7 @@ class StableLM(nn.Module):
|
||||
super().__init__()
|
||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
|
||||
self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
|
||||
self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def __call__(self, x, mask, cache):
|
||||
x = self.embed_tokens(x)
|
||||
|
Reference in New Issue
Block a user