Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
This commit is contained in:
Awni Hannun
2024-03-23 07:13:51 -07:00
committed by GitHub
parent fbed720d6f
commit b8a348c1b8
44 changed files with 144 additions and 1155 deletions

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import LayerNorm
@dataclass
@@ -97,7 +96,7 @@ class TransformerBlock(nn.Module):
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = LayerNorm(
self.input_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)
self.args = args
@@ -125,7 +124,7 @@ class CohereModel(nn.Module):
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = LayerNorm(
self.norm = nn.LayerNorm(
args.hidden_size, eps=args.layer_norm_eps, bias=args.layer_norm_bias
)

View File

@@ -23,13 +23,6 @@ class ModelArgs(BaseModelArgs):
rope_traditional: bool = False
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
x = x.astype(mx.float32)
x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return (1.0 + weight) * x.astype(weight.dtype)
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
@@ -37,7 +30,7 @@ class RMSNorm(nn.Module):
self.eps = eps
def __call__(self, x):
return rms_norm(x, self.weight, self.eps)
return mx.fast.rms_norm(x, 1.0 + self.weight, self.eps)
class Attention(nn.Module):

View File

@@ -1,80 +0,0 @@
from functools import partial
import mlx.core as mx
import mlx.nn as nn
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
x = x.astype(mx.float32)
x = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + eps)
return weight * x.astype(weight.dtype)
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def __call__(self, x):
return rms_norm(x, self.weight, self.eps)
@partial(mx.compile, shapeless=True)
def ln_norm(x, eps, weight=None, bias=None):
"""
Layer normalization for input tensor x.
Args:
x (np.ndarray): Input tensor.
eps (float, optional): Small value to avoid division by zero.
weight (np.ndarray, optional): Weight tensor for normalization.
bias (np.ndarray, optional): Bias tensor for normalization.
Returns:
np.ndarray: Normalized tensor.
"""
t = x.dtype
x = x.astype(mx.float32)
# Compute mean and variance along the last dimension
means = mx.mean(x, axis=-1, keepdims=True)
var = mx.var(x, axis=-1, keepdims=True)
# Normalize the input tensor
x = (x - means) * mx.rsqrt(var + eps)
x = x.astype(t)
# Apply weight and bias if provided
if weight is not None:
x = x * weight
if bias is not None:
x = x + bias
return x
class LayerNorm(nn.Module):
def __init__(
self, dims: int, eps: float = 1e-5, affine: bool = True, bias: bool = True
):
super().__init__()
self.eps = eps
self.dims = dims
self.affine = affine
if affine:
self.weight = mx.ones((dims,))
self.bias = mx.zeros((dims,)) if bias else None
def _extra_repr(self):
return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
def __call__(self, x: mx.array) -> mx.array:
if self.affine:
if self.bias is not None:
return ln_norm(x, self.eps, self.weight, self.bias)
else:
return ln_norm(x, self.eps, self.weight)
else:
return ln_norm(x, self.eps)

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -113,8 +112,10 @@ class TransformerBlock(nn.Module):
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
@@ -141,7 +142,7 @@ class LlamaModel(nn.Module):
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,

View File

@@ -6,7 +6,6 @@ import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -146,7 +145,7 @@ class MixtralSparseMoeBlock(nn.Module):
if self.training:
mx.eval(inds)
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
@@ -173,8 +172,10 @@ class MixtralDecoderLayer(nn.Module):
self.self_attn = MixtralAttention(args)
self.block_sparse_moe = MixtralSparseMoeBlock(args)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
@@ -199,7 +200,7 @@ class MixtralModel(nn.Module):
self.layers = [
MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,

View File

@@ -6,7 +6,6 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import LayerNorm
try:
import hf_olmo
@@ -46,8 +45,8 @@ class TransformerBlock(nn.Module):
self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)
self.att_norm = LayerNorm(dim, affine=False)
self.ff_norm = LayerNorm(dim, affine=False)
self.att_norm = nn.LayerNorm(dim, affine=False)
self.ff_norm = nn.LayerNorm(dim, affine=False)
head_dim = dim // self.n_heads
self.scale = head_dim**-0.5
@@ -120,7 +119,7 @@ class Transformer(nn.Module):
self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
if not self.weight_tying:
self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
self.norm = LayerNorm(args.d_model, affine=False)
self.norm = nn.LayerNorm(args.d_model, affine=False)
def __call__(
self,

View File

@@ -6,7 +6,6 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import LayerNorm
@dataclass
@@ -122,7 +121,9 @@ class PhiDecoderLayer(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()
self.self_attn = PhiAttention(config=config)
self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.mlp = PhiMLP(config)
def __call__(self, x, mask, cache):
@@ -137,7 +138,9 @@ class PhiModel(nn.Module):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [PhiDecoderLayer(config) for i in range(config.num_hidden_layers)]
self.final_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.final_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
def __call__(self, x, cache):
x = self.embed_tokens(x)

View File

@@ -7,8 +7,6 @@ import mlx.core as mx
import mlx.nn as nn
import numpy as np
from .layers import LayerNorm
@dataclass
class ModelArgs:
@@ -116,7 +114,7 @@ class MOE(nn.Module):
if self.training:
ys = []
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
for e, expert in enumerate(self.mlp):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
@@ -141,7 +139,7 @@ class ParallelBlock(nn.Module):
dims = config.model_dim
mlp_dims = dims * 4
self.mixer = RoPEAttention(dims, config.num_heads, config.rotary_dim)
self.ln = LayerNorm(dims)
self.ln = nn.LayerNorm(dims)
self.moe = MOE(config, dims, mlp_dims)
def __call__(self, x, mask, cache):
@@ -179,7 +177,7 @@ class Embd(nn.Module):
class OutputHead(nn.Module):
def __init__(self, config: ModelArgs) -> None:
super().__init__()
self.ln = LayerNorm(config.model_dim)
self.ln = nn.LayerNorm(config.model_dim)
self.linear = nn.Linear(config.model_dim, config.num_vocab)
def __call__(self, inputs):

View File

@@ -6,7 +6,6 @@ import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -82,9 +81,6 @@ class Attention(nn.Module):
# expand shared kv
assert self.k_num_heads == self.v_num_heads
repeats = self.config.n_shared_head
key_states = mx.repeat(key_states, repeats, axis=1)
value_states = mx.repeat(value_states, repeats, axis=1)
kv_seq_len = 0
if cache is not None:
@@ -97,12 +93,14 @@ class Attention(nn.Module):
key_states = mx.concatenate([cache[0], key_states], axis=2)
value_states = mx.concatenate([cache[1], value_states], axis=2)
scores = (query_states * self.scale) @ key_states.transpose(0, 1, 3, 2)
if attention_mask is not None:
scores += attention_mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
output = (scores @ value_states).transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
output = mx.fast.scaled_dot_product_attention(
query_states,
key_states,
value_states,
scale=self.scale,
mask=attention_mask,
)
output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
return self.o_proj(output), (key_states, value_states)
@@ -127,7 +125,7 @@ class PlamoDecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
self.self_attn = Attention(config)
self.mlp = MLP(config)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,
@@ -170,7 +168,7 @@ class PlamoModel(nn.Module):
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = PlamoDecoder(config) # type: ignore
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def __call__(
self,

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -102,9 +101,9 @@ class TransformerBlock(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.ln_1 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.ln_1 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.attn = Attention(args)
self.ln_2 = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.ln_2 = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.mlp = MLP(args)
def __call__(self, x, mask=None, cache=None):
@@ -124,7 +123,7 @@ class QwenModel(nn.Module):
super().__init__()
self.wte = nn.Embedding(args.vocab_size, args.hidden_size)
self.h = [TransformerBlock(args) for _ in range(args.num_hidden_layers)]
self.ln_f = RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
self.ln_f = nn.RMSNorm(args.hidden_size, eps=args.layer_norm_epsilon)
def __call__(self, inputs, mask=None, cache=None):
x = self.wte(inputs)

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -114,8 +113,10 @@ class TransformerBlock(nn.Module):
self.hidden_size = args.hidden_size
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
self.args = args
def __call__(
@@ -142,7 +143,7 @@ class Qwen2Model(nn.Module):
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,

View File

@@ -6,7 +6,6 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import LayerNorm
@dataclass
@@ -120,8 +119,10 @@ class DecoderLayer(nn.Module):
super().__init__()
self.self_attn = Attention(config=config)
self.mlp = MLP(config.hidden_size, config.intermediate_size)
self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.post_attention_layernorm = LayerNorm(
self.input_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.post_attention_layernorm = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
@@ -138,7 +139,7 @@ class StableLM(nn.Module):
super().__init__()
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = [DecoderLayer(config) for i in range(config.num_hidden_layers)]
self.norm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
def __call__(self, x, mask, cache):
x = self.embed_tokens(x)

View File

@@ -5,7 +5,6 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import LayerNorm
@dataclass
@@ -91,8 +90,8 @@ class TransformerBlock(nn.Module):
self.self_attn = Attention(args)
self.mlp = MLP(args.hidden_size, args.intermediate_size)
self.input_layernorm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
self.post_attention_layernorm = LayerNorm(
self.input_layernorm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
self.post_attention_layernorm = nn.LayerNorm(
args.hidden_size, eps=args.norm_epsilon
)
self.args = args
@@ -121,7 +120,7 @@ class Starcoder2Model(nn.Module):
self.layers = [
TransformerBlock(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = LayerNorm(args.hidden_size, eps=args.norm_epsilon)
self.norm = nn.LayerNorm(args.hidden_size, eps=args.norm_epsilon)
def __call__(
self,

View File

@@ -1,4 +1,4 @@
mlx>=0.6
mlx>=0.8
numpy
transformers>=4.38.0
protobuf