Switch to fast RMS/LN Norm (#603)

* use nn.RMSNorm, use sdpa, cleanup

* bump mlx versions

* minor update

* use fast layer norm

* version bump

* update requirement for whisper

* update requirement for gguf
This commit is contained in:
Awni Hannun
2024-03-23 07:13:51 -07:00
committed by GitHub
parent fbed720d6f
commit b8a348c1b8
44 changed files with 144 additions and 1155 deletions

View File

@@ -6,7 +6,6 @@ import mlx.nn as nn
import numpy as np
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -146,7 +145,7 @@ class MixtralSparseMoeBlock(nn.Module):
if self.training:
mx.eval(inds)
inds = np.array(inds)
y = mx.zeros((x.shape[0], ne, x.shape[-1]))
y = mx.zeros((x.shape[0], ne, x.shape[-1]), x.dtype)
for e, expert in enumerate(self.experts):
idx1, idx2 = map(mx.array, np.where(inds == e))
if idx1.size == 0:
@@ -173,8 +172,10 @@ class MixtralDecoderLayer(nn.Module):
self.self_attn = MixtralAttention(args)
self.block_sparse_moe = MixtralSparseMoeBlock(args)
self.input_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.post_attention_layernorm = nn.RMSNorm(
args.hidden_size, eps=args.rms_norm_eps
)
def __call__(
self,
@@ -199,7 +200,7 @@ class MixtralModel(nn.Module):
self.layers = [
MixtralDecoderLayer(args=args) for _ in range(args.num_hidden_layers)
]
self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
def __call__(
self,