Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup
* bump mlx versions
* minor update
* use fast layer norm
* version bump
* update requirement for whisper
* update requirement for gguf
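The core of the change in this file is dropping the hand-written RMSNorm imported from .layers in favor of the built-in nn.RMSNorm, which dispatches to the fused mx.fast.rms_norm kernel. A minimal sketch of the before/after pattern, with hypothetical shapes and a simplified stand-in for the removed custom module (none of this is taken verbatim from the repository):

import mlx.core as mx
import mlx.nn as nn

x = mx.random.normal((2, 8, 512))  # (batch, seq, hidden); hypothetical shapes

# Before: a hand-written RMSNorm along the lines of the removed .layers module
def custom_rms_norm(x, weight, eps=1e-6):
    variance = mx.mean(x.astype(mx.float32) ** 2, axis=-1, keepdims=True)
    return weight * (x * mx.rsqrt(variance + eps)).astype(x.dtype)

# After: the built-in module, backed by the fused mx.fast.rms_norm kernel
norm = nn.RMSNorm(512, eps=1e-6)
y = norm(x)

# The fast op can also be called directly
y_fast = mx.fast.rms_norm(x, norm.weight, 1e-6)

The "LN" half of the commit title presumably refers to the analogous fast layer norm (mx.fast.layer_norm) adopted in other files touched by this commit, which are not part of the hunks shown below.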
@@ -6,7 +6,6 @@ import mlx.nn as nn
 import numpy as np
 
 from .base import BaseModelArgs
-from .layers import RMSNorm
 
 
 @dataclass
@@ -82,9 +81,6 @@ class Attention(nn.Module):
 
         # expand shared kv
         assert self.k_num_heads == self.v_num_heads
-        repeats = self.config.n_shared_head
-        key_states = mx.repeat(key_states, repeats, axis=1)
-        value_states = mx.repeat(value_states, repeats, axis=1)
 
         kv_seq_len = 0
         if cache is not None:
@@ -97,12 +93,14 @@ class Attention(nn.Module):
             key_states = mx.concatenate([cache[0], key_states], axis=2)
             value_states = mx.concatenate([cache[1], value_states], axis=2)
 
-        scores = (query_states * self.scale) @ key_states.transpose(0, 1, 3, 2)
-        if attention_mask is not None:
-            scores += attention_mask
-        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
-        output = (scores @ value_states).transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
-
+        output = mx.fast.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            scale=self.scale,
+            mask=attention_mask,
+        )
+        output = output.transpose(0, 2, 1, 3).reshape(bsz, q_len, -1)
         return self.o_proj(output), (key_states, value_states)
 
 
@@ -127,7 +125,7 @@ class PlamoDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         self.self_attn = Attention(config)
         self.mlp = MLP(config)
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def __call__(
         self,
@@ -170,7 +168,7 @@ class PlamoModel(nn.Module):
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
         self.layers = PlamoDecoder(config)  # type: ignore
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def __call__(
         self,
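The attention hunk above replaces the explicit softmax((Q * scale) K^T) V computation with the fused mx.fast.scaled_dot_product_attention kernel. A rough numerical sanity check of that equivalence, using hypothetical shapes rather than the model's actual configuration:

import mlx.core as mx

# Hypothetical sizes; PLaMo's real head count and dimensions come from its config
bsz, n_heads, q_len, head_dim = 1, 8, 16, 64
scale = head_dim ** -0.5

q = mx.random.normal((bsz, n_heads, q_len, head_dim))
k = mx.random.normal((bsz, n_heads, q_len, head_dim))
v = mx.random.normal((bsz, n_heads, q_len, head_dim))

# Unfused reference, mirroring the removed code path
scores = (q * scale) @ k.transpose(0, 1, 3, 2)
scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
reference = scores @ v

# Fused kernel, mirroring the added code path
fused = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask=None)

print(mx.allclose(reference, fused, atol=1e-4))  # expect True up to numerical tolerance

Since the fused kernel handles grouped key/value heads by broadcasting, the manual mx.repeat expansion of the shared KV heads dropped in the second hunk is presumably no longer needed, which is consistent with that cleanup.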