Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup
* bump mlx versions
* minor update
* use fast layer norm
* version bump
* update requirement for whisper
* update requirement for gguf
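For context, the "fast RMS/LN Norm" in the title refers to the fused normalization kernels under mx.fast. A minimal sketch of how they are called, assuming mlx >= 0.8.0 and the mx.fast.rms_norm(x, weight, eps) / mx.fast.layer_norm(x, weight, bias, eps) signatures; nn.RMSNorm, used in the diff below, routes through the same fast path per the commit title:

# Sketch only: fused normalization ops from mx.fast (assumed signatures).
import mlx.core as mx

x = mx.random.normal((4, 512))
weight = mx.ones((512,))
bias = mx.zeros((512,))

# RMSNorm: x * rsqrt(mean(x^2, axis=-1) + eps) * weight
y_rms = mx.fast.rms_norm(x, weight, 1e-5)
# LayerNorm: (x - mean) / sqrt(var + eps) * weight + bias
y_ln = mx.fast.layer_norm(x, weight, bias, 1e-5)

print(y_rms.shape, y_ln.shape)  # (4, 512) (4, 512)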
@@ -45,20 +45,6 @@ class TextConfig:
                 raise ValueError("rope_scaling 'type' currently only supports 'linear'")
 
 
-class RMSNorm(nn.Module):
-    def __init__(self, dims: int, eps: float = 1e-5):
-        super().__init__()
-        self.weight = mx.ones((dims,))
-        self.eps = eps
-
-    def _norm(self, x):
-        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
-
-    def __call__(self, x):
-        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
-        return self.weight * output
-
-
 class Attention(nn.Module):
     def __init__(self, config: TextConfig):
         super().__init__()
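The deleted class is a drop-in for nn.RMSNorm, which takes the same (dims, eps) arguments and starts from a ones-initialized weight. A quick equivalence check, as a sketch assuming mlx is installed:

# Sketch: the removed hand-written RMSNorm and nn.RMSNorm should agree numerically.
import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 8, 64))

fast = nn.RMSNorm(64, eps=1e-5)  # weight initialized to ones, like mx.ones((dims,))
manual = x * mx.rsqrt(x.square().mean(-1, keepdims=True) + 1e-5)  # removed _norm, with weight == 1

print(mx.allclose(fast(x), manual, atol=1e-5))  # expected: array(True, dtype=bool)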
@@ -105,10 +91,6 @@ class Attention(nn.Module):
         keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
 
-        if self.repeats > 1:
-            keys = mx.repeat(keys, self.repeats, axis=1)
-            values = mx.repeat(values, self.repeats, axis=1)
-
         if cache is not None:
             key_cache, value_cache = cache
             queries = self.rope(queries, offset=key_cache.shape[2])
@@ -119,11 +101,10 @@ class Attention(nn.Module):
             queries = self.rope(queries)
             keys = self.rope(keys)
 
-        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
-        if mask is not None:
-            scores += mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=self.scale, mask=mask
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
         return self.o_proj(output), (keys, values)
 
 
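With the fused kernel, the manual scale/softmax/matmul sequence removed above becomes a single call, and (consistent with dropping mx.repeat in the previous hunk) keys and values can keep their n_kv_heads heads. A shape-only sketch with illustrative sizes, assuming mlx >= 0.8.0:

# Sketch: fused attention with grouped KV heads; layouts match the diff,
# queries (B, n_heads, L, head_dim), keys/values (B, n_kv_heads, L, head_dim).
import mlx.core as mx

B, L, n_heads, n_kv_heads, head_dim = 1, 16, 8, 2, 64
queries = mx.random.normal((B, n_heads, L, head_dim))
keys = mx.random.normal((B, n_kv_heads, L, head_dim))
values = mx.random.normal((B, n_kv_heads, L, head_dim))

output = mx.fast.scaled_dot_product_attention(
    queries, keys, values, scale=head_dim**-0.5, mask=None
)
print(output.shape)  # (1, 8, 16, 64)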
@@ -145,8 +126,8 @@ class TransformerBlock(nn.Module):
         self.hidden_size = config.hidden_size
         self.self_attn = Attention(config)
         self.mlp = MLP(config.hidden_size, config.intermediate_size)
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(
+        self.input_layernorm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = nn.RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
         self.config = config
@@ -175,7 +156,7 @@ class Llama(nn.Module):
         self.layers = [
             TransformerBlock(config=config) for _ in range(config.num_hidden_layers)
         ]
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def __call__(
         self,
@@ -1,4 +1,4 @@
-mlx>=0.5.0
+mlx>=0.8.0
 numpy
 transformers
 torch