[mlx-lm] Add precompiled normalizations (#451)

* add precompiled normalizations

* nits
This commit is contained in:
Awni Hannun
2024-02-22 12:40:55 -08:00
committed by GitHub
parent 97c09a863d
commit f24edfa9dc
13 changed files with 74 additions and 105 deletions

View File

@@ -5,6 +5,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import RMSNorm
@dataclass
@@ -34,20 +35,6 @@ class ModelArgs(BaseModelArgs):
raise ValueError("rope_scaling 'type' currently only supports 'linear'")
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-6):
super().__init__()
self.weight = mx.ones((dims,))
self.eps = eps
def _norm(self, x):
return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)
def __call__(self, x):
output = self._norm(x.astype(mx.float32)).astype(x.dtype)
return self.weight * output
class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()