Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-08-31 20:04:38 +08:00)
Switch to fast RMS/LN Norm (#603)
* use nn.RMSNorm, use sdpa, cleanup
* bump mlx versions
* minor update
* use fast layer norm
* version bump
* update requirement for whisper
* update requirement for gguf
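The diff below drops whisper's custom `LayerNorm` wrapper, which upcast activations to float32 before normalizing and cast the result back, in favor of plain `nn.LayerNorm`. A minimal sketch of the before and after, assuming mlx >= 0.8 where `nn.LayerNorm` is backed by the fused fast layer-norm kernel (per the commit title; the dispatch itself is not shown in this diff):

```python
import mlx.core as mx
import mlx.nn as nn

# Before: upcast to float32 so the mean/variance reduction stays
# numerically stable when the model runs in float16.
class LayerNorm(nn.LayerNorm):
    def __call__(self, x: mx.array) -> mx.array:
        return super().__call__(x.astype(mx.float32)).astype(x.dtype)

# After: plain nn.LayerNorm; the fused kernel handles the reduction,
# so the manual round-trip through float32 is no longer needed.
ln = nn.LayerNorm(dims=512)
x = mx.random.normal((1, 10, 512)).astype(mx.float16)
y = ln(x)
```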
@@ -1,4 +1,4 @@
-mlx>=0.1
+mlx>=0.8
 numba
 numpy
 torch
@@ -37,11 +37,6 @@ def sinusoids(length, channels, max_timescale=10000):
     return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
 
 
-class LayerNorm(nn.LayerNorm):
-    def __call__(self, x: mx.array) -> mx.array:
-        return super().__call__(x.astype(mx.float32)).astype(x.dtype)
-
-
 class MultiHeadAttention(nn.Module):
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
@@ -98,17 +93,17 @@ class ResidualAttentionBlock(nn.Module):
         super().__init__()
 
         self.attn = MultiHeadAttention(n_state, n_head)
-        self.attn_ln = LayerNorm(n_state)
+        self.attn_ln = nn.LayerNorm(n_state)
 
         self.cross_attn = (
             MultiHeadAttention(n_state, n_head) if cross_attention else None
         )
-        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
 
         n_mlp = n_state * 4
         self.mlp1 = nn.Linear(n_state, n_mlp)
         self.mlp2 = nn.Linear(n_mlp, n_state)
-        self.mlp_ln = LayerNorm(n_state)
+        self.mlp_ln = nn.LayerNorm(n_state)
 
     def __call__(self, x, xa=None, mask=None, kv_cache=None):
         kv, cross_kv = kv_cache if kv_cache else (None, None)
@@ -140,7 +135,7 @@ class AudioEncoder(nn.Module):
         self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
 
         self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
-        self.ln_post = LayerNorm(n_state)
+        self.ln_post = nn.LayerNorm(n_state)
 
     def __call__(self, x):
         x = nn.gelu(self.conv1(x)).astype(x.dtype)
@@ -174,7 +169,7 @@ class TextDecoder(nn.Module):
             ResidualAttentionBlock(n_state, n_head, cross_attention=True)
             for _ in range(n_layer)
         ]
-        self.ln = LayerNorm(n_state)
+        self.ln = nn.LayerNorm(n_state)
         self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
             dtype
         )
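Beyond whisper, the commit message also mentions `nn.RMSNorm` and SDPA in other models. A hedged sketch of calling the fused ops directly, assuming `mx.fast.layer_norm`, `mx.fast.rms_norm`, and `mx.fast.scaled_dot_product_attention` are available (mlx >= 0.8; the shapes and names here are illustrative, not taken from this diff):

```python
import math
import mlx.core as mx

B, H, L, D = 1, 8, 64, 64  # batch, heads, sequence length, head dim

x = mx.random.normal((B, L, H * D))
weight = mx.ones((H * D,))
bias = mx.zeros((H * D,))

# Fused normalization kernels (eps passed positionally).
y_ln = mx.fast.layer_norm(x, weight, bias, 1e-5)
y_rms = mx.fast.rms_norm(x, weight, 1e-5)

# Fused scaled dot-product attention ("use sdpa" in the commit message);
# q/k/v are laid out as (batch, heads, sequence, head_dim).
q = k = v = mx.random.normal((B, H, L, D))
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0 / math.sqrt(D))
```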