Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
Author: Awni Hannun
Date: 2024-03-23 07:13:51 -07:00
Committed by: GitHub
Parent: fbed720d6f
Commit: b8a348c1b8
44 changed files with 144 additions and 1155 deletions

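The substance of the change: the models drop their hand-written normalization and attention math in favor of the built-in fast primitives (nn.RMSNorm, nn.LayerNorm, and mx.fast.scaled_dot_product_attention). A minimal before/after sketch of the pattern — not the exact code this commit deletes; the SlowRMSNorm name and the dimensions are illustrative:

```python
import mlx.core as mx
import mlx.nn as nn


# Before: the kind of hand-rolled RMSNorm the models carried themselves
# (illustrative; not the exact code removed here).
class SlowRMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        # Normalize by the root-mean-square over the last axis, then scale.
        rms = mx.rsqrt(mx.square(x).mean(-1, keepdims=True) + self.eps)
        return self.weight * x * rms


# After: the built-in module, which dispatches to the fused fast kernel.
norm = nn.RMSNorm(4096, eps=1e-5)


# "use sdpa": explicit softmax(Q K^T * scale) V is replaced by the fused op.
def attention(queries, keys, values, scale, mask=None):
    return mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=scale, mask=mask
    )
```

Routing everything through the stock nn modules and fast ops is presumably why the commit deletes far more lines than it adds: the per-model norm implementations become unnecessary. The OLMo file below is one representative example.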

@@ -6,7 +6,6 @@ import mlx.core as mx
 import mlx.nn as nn

 from .base import BaseModelArgs
-from .layers import LayerNorm

 try:
     import hf_olmo
@@ -46,8 +45,8 @@ class TransformerBlock(nn.Module):
         self.ff_proj = nn.Linear(dim, args.mlp_hidden_size, bias=False)
         self.ff_out = nn.Linear(args.mlp_hidden_size // 2, dim, bias=False)

-        self.att_norm = LayerNorm(dim, affine=False)
-        self.ff_norm = LayerNorm(dim, affine=False)
+        self.att_norm = nn.LayerNorm(dim, affine=False)
+        self.ff_norm = nn.LayerNorm(dim, affine=False)

         head_dim = dim // self.n_heads
         self.scale = head_dim**-0.5
@@ -120,7 +119,7 @@ class Transformer(nn.Module):
         self.blocks = [TransformerBlock(args=args) for _ in range(args.n_layers)]
         if not self.weight_tying:
             self.ff_out = nn.Linear(args.d_model, args.embedding_size, bias=False)
-        self.norm = LayerNorm(args.d_model, affine=False)
+        self.norm = nn.LayerNorm(args.d_model, affine=False)

     def __call__(
         self,
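Because the OLMo norms are constructed with affine=False, neither the removed custom LayerNorm nor the stock nn.LayerNorm carries learned parameters here, so the swap only changes which implementation performs the normalization. A small usage sketch — the shapes are illustrative, and mx.fast.layer_norm is shown assuming its (x, weight, bias, eps) signature with weight and bias allowed to be None:

```python
import mlx.core as mx
import mlx.nn as nn

# Non-affine layer norm: no learned weight or bias, just normalization over
# the last axis, matching the att_norm / ff_norm construction above.
norm = nn.LayerNorm(2048, affine=False)

x = mx.random.normal((2, 16, 2048))  # (batch, sequence, hidden) -- illustrative
y = norm(x)

# The same computation through the fused fast kernel, with no scale or shift:
y_fast = mx.fast.layer_norm(x, weight=None, bias=None, eps=1e-5)
```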