Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
This commit is contained in:
Awni Hannun
2024-03-23 07:13:51 -07:00
committed by GitHub
parent fbed720d6f
commit b8a348c1b8
44 changed files with 144 additions and 1155 deletions

View File

@@ -1,4 +1,4 @@
mlx>=0.1
mlx>=0.8
numba
numpy
torch

View File

@@ -37,11 +37,6 @@ def sinusoids(length, channels, max_timescale=10000):
return mx.concatenate([mx.sin(scaled_time), mx.cos(scaled_time)], axis=1)
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class MultiHeadAttention(nn.Module):
def __init__(self, n_state: int, n_head: int):
super().__init__()
@@ -98,17 +93,17 @@ class ResidualAttentionBlock(nn.Module):
super().__init__()
self.attn = MultiHeadAttention(n_state, n_head)
self.attn_ln = LayerNorm(n_state)
self.attn_ln = nn.LayerNorm(n_state)
self.cross_attn = (
MultiHeadAttention(n_state, n_head) if cross_attention else None
)
self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None
n_mlp = n_state * 4
self.mlp1 = nn.Linear(n_state, n_mlp)
self.mlp2 = nn.Linear(n_mlp, n_state)
self.mlp_ln = LayerNorm(n_state)
self.mlp_ln = nn.LayerNorm(n_state)
def __call__(self, x, xa=None, mask=None, kv_cache=None):
kv, cross_kv = kv_cache if kv_cache else (None, None)
@@ -140,7 +135,7 @@ class AudioEncoder(nn.Module):
self._positional_embedding = sinusoids(n_ctx, n_state).astype(dtype)
self.blocks = [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
self.ln_post = LayerNorm(n_state)
self.ln_post = nn.LayerNorm(n_state)
def __call__(self, x):
x = nn.gelu(self.conv1(x)).astype(x.dtype)
@@ -174,7 +169,7 @@ class TextDecoder(nn.Module):
ResidualAttentionBlock(n_state, n_head, cross_attention=True)
for _ in range(n_layer)
]
self.ln = LayerNorm(n_state)
self.ln = nn.LayerNorm(n_state)
self._mask = nn.MultiHeadAttention.create_additive_causal_mask(n_ctx).astype(
dtype
)