Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-10-23 22:18:06 +08:00)
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup
* bump mlx versions
* minor update
* use fast layer norm
* version bump
* update requirement for whisper
* update requirement for gguf
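The same pattern repeats across the hunks below: the per-model rms_norm/ln_norm helpers (and the shared layers module that held them) are dropped in favor of the built-in norm modules, which are backed by the fused mx.fast kernels in recent MLX releases. A minimal sketch of the replacement, with illustrative shapes and eps values:

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((2, 8, 512))

    # Module form used throughout the diff below.
    y = nn.RMSNorm(512, eps=1e-5)(x)

    # Direct call to the fused op with an explicit weight vector.
    w = mx.ones((512,))
    y = mx.fast.rms_norm(x, w, 1e-5)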
@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .layers import LayerNorm
 
 
 @dataclass
@@ -97,7 +96,7 @@ class TransformerBlock(nn.Module):
 
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = LayerNorm(
+        self.input_layernorm = nn.LayerNorm(
             args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
         )
         self.args = args
@@ -125,7 +124,7 @@ class CohereModel(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = LayerNorm(
+        self.norm = nn.LayerNorm(
             args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
         )
 
@@ -23,13 +23,6 @@ class ModelArgs(BaseModelArgs):
     rope_traditional: bool = False
 
 
-@partial(mx.compile, shapeless=True)
-def rms_norm(x, weight, eps):
-    x = x.astype(mx.float32)
-    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
-    return (1.0 + weight) * x.astype(weight.dtype)
-
-
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
         super().__init__()
@@ -37,7 +30,7 @@ class RMSNorm(nn.Module):
         self.eps = eps
 
     def __call__(self, x):
-        return rms_norm(x, self.weight, self.eps)
+        return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
 
 
 class Attention(nn.Module):
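The norm above scales activations by (1 + weight) rather than weight (the Gemma convention); the rewrite keeps that behavior by shifting the weight at call time instead of inside a custom compiled kernel. A sketch of the resulting module, with an illustrative weight initialization that is not taken from the diff:

    import mlx.core as mx
    import mlx.nn as nn

    class GemmaStyleRMSNorm(nn.Module):
        def __init__(self, dims: int, eps: float = 1e-5):
            super().__init__()
            # Illustrative: stored as an offset from 1.0, so zeros give an identity scale.
            self.weight = mx.zeros((dims,))
            self.eps = eps

        def __call__(self, x):
            # The fused kernel computes weight * normalized(x), so shift by 1.0 here.
            return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)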
@@ -1,80 +0,0 @@
-from functools import partial
-
-import mlx.core as mx
-import mlx.nn as nn
-
-
-@partial(mx.compile, shapeless=True)
-def rms_norm(x, weight, eps):
-    x = x.astype(mx.float32)
-    x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
-    return weight * x.astype(weight.dtype)
-
-
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def __call__(self, x):
-        return rms_norm(x, self.weight, self.eps)
-
-
-@partial(mx.compile, shapeless=True)
-def ln_norm(x, eps, weight=None, bias=None):
-    """
-    Layer normalization for input tensor x.
-
-    Args:
-        x (np.ndarray): Input tensor.
-        eps (float, optional): Small value to avoid division by zero.
-        weight (np.ndarray, optional): Weight tensor for normalization.
-        bias (np.ndarray, optional): Bias tensor for normalization.
-
-    Returns:
-        np.ndarray: Normalized tensor.
-    """
-    t = x.dtype
-    x = x.astype(mx.float32)
-
-    # Compute mean and variance along the last dimension
-    means = mx.mean(x, axis=-1, keepdims=True)
-    var = mx.var(x, axis=-1, keepdims=True)
-
-    # Normalize the input tensor
-    x = (x - means) * mx.rsqrt(var + eps)
-    x = x.astype(t)
-
-    # Apply weight and bias if provided
-    if weight is not None:
-        x = x * weight
-    if bias is not None:
-        x = x + bias
-    return x
-
-
-class LayerNorm(nn.Module):
-    def __init__(
-        self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
-    ):
-        super().__init__()
-        self.eps = eps
-        self.dims = dims
-        self.affine = affine
-
-        if affine:
-            self.weight = mx.ones((dims,))
-            self.bias = mx.zeros((dims,)) if bias else None
-
-    def _extra_repr(self):
-        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
-
-    def __call__(self, x: mx.array) -> mx.array:
-        if self.affine:
-            if self.bias is not None:
-                return ln_norm(x, self.eps, self.weight, self.bias)
-            else:
-                return ln_norm(x, self.eps, self.weight)
-        else:
-            return ln_norm(x, self.eps)
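The deleted helpers map onto built-ins available after the mlx>=0.8 bump at the bottom of this diff. A rough equivalence sketch, with illustrative shapes and eps (mx.fast.layer_norm takes None for a missing weight or bias):

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.random.normal((4, 256))
    w = mx.ones((256,))
    b = mx.zeros((256,))

    y1 = mx.fast.layer_norm(x, w, b, 1e-5)        # affine LayerNorm with bias
    y2 = mx.fast.layer_norm(x, w, None, 1e-5)     # affine, no bias
    y3 = mx.fast.layer_norm(x, None, None, 1e-5)  # no affine, like the OLMo norms below
    y4 = nn.LayerNorm(256, eps=1e-5)(x)           # module form used throughout the diff
    y5 = mx.fast.rms_norm(x, w, 1e-5)             # replacement for the compiled rms_norm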
@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .layers import RMSNorm
 
 
 @dataclass
@@ -113,8 +112,10 @@ class TransformerBlock(nn.Module):
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
         self.args = args
 
     def __call__(
@@ -141,7 +142,7 @@ class LlamaModel(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
     def __call__(
         self,
@@ -6,7 +6,6 @@ import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
-from .layers import RMSNorm
 
 
 @dataclass
@@ -146,7 +145,7 @@ class MixtralSparseMoeBlock(nn.Module):
         if self.training:
             mx.eval(inds)
             inds = np.array(inds)
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
+            y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
             for e, expert in enumerate(self.experts):
                 idx1, idx2 = map(mx.array, np.where(inds == e))
                 if idx1.size == 0:
@@ -173,8 +172,10 @@ class MixtralDecoderLayer(nn.Module):
         self.self_attn = MixtralAttention(args)
 
         self.block_sparse_moe = MixtralSparseMoeBlock(args)
-        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
 
     def __call__(
         self,
@@ -199,7 +200,7 @@ class MixtralModel(nn.Module):
         self.layers = [
             MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
     def __call__(
         self,
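Aside from the norm swap, the Mixtral hunk above passes x.dtype to mx.zeros so the expert accumulator matches the activation dtype instead of defaulting to float32. A small illustration with arbitrary shapes:

    import mlx.core as mx

    x = mx.random.normal((3, 2, 64)).astype(mx.float16)
    ne = 2  # number of selected experts, illustrative

    y_old = mx.zeros((x.shape[0], ne, x.shape[-1]))           # defaults to float32
    y_new = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)  # matches the activations
    assert y_new.dtype == x.dtype and y_old.dtype == mx.float32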
@@ -6,7 +6,6 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .layers import LayerNorm
 
 try:
     import hf_olmo
@@ -46,8 +45,8 @@ class TransformerBlock(nn.Module):
         self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
         self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)
 
-        self.att_norm = LayerNorm(dim, affine=False)
-        self.ff_norm = LayerNorm(dim, affine=False)
+        self.att_norm = nn.LayerNorm(dim, affine=False)
+        self.ff_norm = nn.LayerNorm(dim, affine=False)
 
         head_dim = dim // self.n_heads
         self.scale = head_dim**-0.5
@@ -120,7 +119,7 @@ class Transformer(nn.Module):
         self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
         if not self.weight_tying:
             self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
-        self.norm = LayerNorm(args.d_model, affine=False)
+        self.norm = nn.LayerNorm(args.d_model, affine=False)
 
     def __call__(
         self,
@@ -6,7 +6,6 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .layers import LayerNorm
 
 
 @dataclass
@@ -122,7 +121,9 @@ class PhiDecoderLayer(nn.Module):
     def __init__(self, config: ModelArgs):
         super().__init__()
        self.self_attn = PhiAttention(config=config)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.input_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
         self.mlp = PhiMLP(config)
 
     def __call__(self, x, mask, cache):
@@ -137,7 +138,9 @@ class PhiModel(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
-        self.final_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.final_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
 
     def __call__(self, x, cache):
         x = self.embed_tokens(x)
@@ -7,8 +7,6 @@ import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 
-from .layers import LayerNorm
-
 
 @dataclass
 class ModelArgs:
@@ -116,7 +114,7 @@ class MOE(nn.Module):
 
         if self.training:
             ys = []
-            y = mx.zeros((x.shape[0], ne, x.shape[-1]))
+            y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
             for e, expert in enumerate(self.mlp):
                 idx1, idx2 = map(mx.array, np.where(inds == e))
                 if idx1.size == 0:
@@ -141,7 +139,7 @@ class ParallelBlock(nn.Module):
         dims = config.model_dim
         mlp_dims = dims * 4
         self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
-        self.ln = LayerNorm(dims)
+        self.ln = nn.LayerNorm(dims)
         self.moe = MOE(config, dims, mlp_dims)
 
     def __call__(self, x, mask, cache):
@@ -179,7 +177,7 @@ class Embd(nn.Module):
 class OutputHead(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
-        self.ln = LayerNorm(config.model_dim)
+        self.ln = nn.LayerNorm(config.model_dim)
         self.linear = nn.Linear(config.model_dim, config.num_vocab)
 
     def __call__(self, inputs):
@@ -6,7 +6,6 @@ import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
-from .layers import RMSNorm
 
 
 @dataclass
@@ -82,9 +81,6 @@ class Attention(nn.Module):
 
         # expand shared kv
         assert self.k_num_heads == self.v_num_heads
-        repeats = self.config.n_shared_head
-        key_states = mx.repeat(key_states, repeats, axis=1)
-        value_states = mx.repeat(value_states, repeats, axis=1)
 
         kv_seq_len = 0
         if cache is not None:
@@ -97,12 +93,14 @@
             key_states = mx.concatenate([cache[0], key_states], axis=2)
             value_states = mx.concatenate([cache[1], value_states], axis=2)
 
-        scores = (query_states * self.scale) @ key_states.transpose(0, 1, 3, 2)
-        if attention_mask is not None:
-            scores += attention_mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ value_states).transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
-
+        output = mx.fast.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            scale=self.scale,
+            mask=attention_mask,
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
         return self.o_proj(output), (key_states, value_states)
 
 
@@ -127,7 +125,7 @@ class PlamoDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         self.self_attn = Attention(config)
         self.mlp = MLP(config)
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def __call__(
         self,
@@ -170,7 +168,7 @@ class PlamoModel(nn.Module):
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = PlamoDecoder(config)  # type: ignore
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def __call__(
         self,
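The PLaMo attention above is the one place this diff swaps a hand-written softmax(Q @ K.T) @ V path for the fused kernel. A standalone sketch of the call, with illustrative batch, head, and length sizes and no mask:

    import mlx.core as mx

    B, H, L, D = 1, 8, 16, 64
    scale = D**-0.5

    q = mx.random.normal((B, H, L, D))
    k = mx.random.normal((B, H, L, D))
    v = mx.random.normal((B, H, L, D))

    out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)
    out = out.transpose(0, 2, 1, 3).reshape(B, L, -1)  # back to (batch, seq, hidden)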
@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .layers import RMSNorm
 
 
 @dataclass
@@ -102,9 +101,9 @@ class TransformerBlock(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
 
-        self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
         self.attn = Attention(args)
-        self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
         self.mlp = MLP(args)
 
     def __call__(self, x, mask=None, cache=None):
@@ -124,7 +123,7 @@ class QwenModel(nn.Module):
         super().__init__()
         self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
         self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
-        self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
+        self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
 
     def __call__(self, inputs, mask=None, cache=None):
         x = self.wte(inputs)
@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .layers import RMSNorm
 
 
 @dataclass
@@ -114,8 +113,10 @@ class TransformerBlock(nn.Module):
         self.hidden_size = args.hidden_size
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
         self.args = args
 
     def __call__(
@@ -142,7 +143,7 @@ class Qwen2Model(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
+        self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
 
     def __call__(
         self,
@@ -6,7 +6,6 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .layers import LayerNorm
 
 
 @dataclass
@@ -120,8 +119,10 @@ class DecoderLayer(nn.Module):
         super().__init__()
         self.self_attn = Attention(config=config)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.post_attention_layernorm = LayerNorm(
+        self.input_layernorm = nn.LayerNorm(
+            config.hidden_size, eps=config.layer_norm_eps
+        )
+        self.post_attention_layernorm = nn.LayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
 
@@ -138,7 +139,7 @@ class StableLM(nn.Module):
         super().__init__()
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
-        self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
     def __call__(self, x, mask, cache):
         x = self.embed_tokens(x)
@@ -5,7 +5,6 @@ import mlx.core as mx
 import mlx.nn as nn
 
 from .base import BaseModelArgs
-from .layers import LayerNorm
 
 
 @dataclass
@@ -91,8 +90,8 @@ class TransformerBlock(nn.Module):
 
         self.self_attn = Attention(args)
         self.mlp = MLP(args.hidden_size, args.intermediate_size)
-        self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
-        self.post_attention_layernorm = LayerNorm(
+        self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
+        self.post_attention_layernorm = nn.LayerNorm(
             args.hidden_size, eps=args.norm_epsilon
         )
         self.args = args
@@ -121,7 +120,7 @@ class Starcoder2Model(nn.Module):
         self.layers = [
             TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
         ]
-        self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
+        self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
 
     def __call__(
         self,
@@ -1,4 +1,4 @@
-mlx>=0.6
+mlx>=0.8
 numpy
 transformers>=4.38.0
 protobuf