Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
Awni Hannun authored 2024-03-23 07:13:51 -07:00, committed by GitHub
parent fbed720d6f, commit b8a348c1b8
44 changed files with 144 additions and 1155 deletions

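The gist of the change in the hunks below: the hand-written RMSNorm imported from .layers is dropped in favor of nn.RMSNorm, which dispatches to the fused mx.fast.rms_norm kernel. For reference only (not from this commit), a minimal sketch of the two paths, assuming the removed helper used the standard RMSNorm formulation (mean of squares in float32, learned scale applied afterwards) and using illustrative shapes:

import mlx.core as mx

def rms_norm_reference(x, weight, eps):
    # hand-written path: normalize by the root mean square, computed in float32
    t = x.astype(mx.float32)
    t = t * mx.rsqrt(mx.mean(t * t, axis=-1, keepdims=True) + eps)
    return weight * t.astype(x.dtype)

x = mx.random.normal((2, 8, 512))   # (batch, sequence, hidden) -- illustrative
w = mx.ones((512,))
print(mx.allclose(rms_norm_reference(x, w, 1e-5),
                  mx.fast.rms_norm(x, w, 1e-5), atol=1e-5))  # expected: True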

@@ -6,7 +6,6 @@ import mlx.nn as nn
 import numpy as np
 from .base import BaseModelArgs
-from .layers import RMSNorm
 @dataclass
@@ -82,9 +81,6 @@ class Attention(nn.Module):
        # expand shared kv
        assert self.k_num_heads == self.v_num_heads
        repeats = self.config.n_shared_head
        key_states = mx.repeat(key_states, repeats, axis=1)
        value_states = mx.repeat(value_states, repeats, axis=1)
        kv_seq_len = 0
        if cache is not None:
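As an aside, the mx.repeat calls above expand the shared key/value heads so they line up one-to-one with the query heads. A toy illustration, with head counts made up for the example:

import mlx.core as mx

bsz, q_len, head_dim = 1, 16, 64
k = mx.random.normal((bsz, 2, q_len, head_dim))  # 2 kv heads
k = mx.repeat(k, 4, axis=1)                      # each kv head serves 4 query heads
print(k.shape)                                   # (1, 8, 16, 64)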
@@ -97,12 +93,14 @@ class Attention(nn.Module):
             key_states = mx.concatenate([cache[0], key_states], axis=2)
             value_states = mx.concatenate([cache[1], value_states], axis=2)
-        scores = (query_states * self.scale) @ key_states.transpose(0, 1, 3, 2)
-        if attention_mask is not None:
-            scores += attention_mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ value_states).transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
+        output = mx.fast.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            scale=self.scale,
+            mask=attention_mask,
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
         return self.o_proj(output), (key_states, value_states)
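The removed lines are the textbook attention computation (scale the queries, add the optional additive mask, softmax in float32, weight the values); mx.fast.scaled_dot_product_attention fuses the same steps into one kernel. A standalone sketch of the equivalence, not taken from the repo, with made-up shapes and a causal additive mask:

import mlx.core as mx

B, H, L, D = 1, 4, 16, 64
scale = D ** -0.5
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))
mask = mx.triu(mx.full((L, L), float("-inf")), k=1)  # additive causal mask

# explicit path, mirroring the removed lines
scores = (q * scale) @ k.transpose(0, 1, 3, 2)
scores = scores + mask
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
out_ref = scores @ v

# fused path, mirroring the added lines
out_fast = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=mask)
print(mx.allclose(out_ref, out_fast, atol=1e-4))  # expected: True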
@@ -127,7 +125,7 @@ class PlamoDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         self.self_attn = Attention(config)
         self.mlp = MLP(config)
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
     def __call__(
         self,
@@ -170,7 +168,7 @@ class PlamoModel(nn.Module):
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = PlamoDecoder(config)  # type: ignore
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
     def __call__(
         self,
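Since nn.RMSNorm keeps its learned scale in a single weight parameter (and the call sites above are otherwise untouched, suggesting the removed class stored its scale the same way), the swap is a drop-in: existing checkpoints keyed on ...norm.weight load unchanged. A quick check, with an illustrative hidden size rather than Plamo's real config:

import mlx.core as mx
import mlx.nn as nn

hidden_size, eps = 512, 1e-6            # illustrative values
norm = nn.RMSNorm(hidden_size, eps=eps)
print(norm.parameters())                # {'weight': array of shape (512,)}

x = mx.random.normal((2, 16, hidden_size))
print(norm(x).shape)                    # (2, 16, 512)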