[mlx-lm] Add precompiled normalizations (#451)

* add precompiled normalizations

* nits
This commit is contained in:
Awni Hannun
2024-02-22 12:40:55 -08:00
committed by GitHub
parent 97c09a863d
commit f24edfa9dc
13 changed files with 74 additions and 105 deletions

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -26,20 +27,6 @@ class ModelArgs(BaseModelArgs):
self.num_key_value_heads = self.num_attention_heads
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()