[mlx-lm] Add precompiled normalizations (#451)

* add precompiled normalizations

* nits
This commit is contained in:
Awni Hannun
2024-02-22 12:40:55 -08:00
committed by GitHub
parent 97c09a863d
commit f24edfa9dc
13 changed files with 74 additions and 105 deletions

View File

@@ -1,4 +1,5 @@
from dataclasses import dataclass
from functools import partial
from typing import Dict, Optional, Tuple, Union
import mlx.core as mx
@@ -22,18 +23,21 @@ class ModelArgs(BaseModelArgs):
rope_traditional: bool = False
@partial(mx.compile, shapeless=True)
def rms_norm(x, weight, eps):
    """Apply root-mean-square normalization with a ``(1 + weight)`` scale.

    The input is cast to float32, divided by the root of the mean of its
    squares along the last axis (with ``eps`` added under the root), cast
    to ``weight.dtype`` and finally multiplied by ``1.0 + weight``.

    Args:
        x: Array to normalize over its last axis.
        weight: Scale offset; the effective scale is ``1.0 + weight``.
        eps: Constant added to the mean of squares before ``rsqrt``.

    Returns:
        The normalized and scaled array.
    """
    # Compute the statistics in float32 regardless of the input dtype.
    x_f32 = x.astype(mx.float32)
    mean_sq = x_f32.square().mean(-1, keepdims=True)
    normed = x_f32 * mx.rsqrt(mean_sq + eps)
    # Cast to the weight's dtype before applying the scale.
    return (1.0 + weight) * normed.astype(weight.dtype)
class RMSNorm(nn.Module):
    """RMS normalization layer whose scale is ``1 + weight``.

    The computation is delegated to the module-level compiled
    ``rms_norm`` function.
    """

    def __init__(self, dims: int, eps: float = 1e-5):
        """Create the layer.

        Args:
            dims: Length of the ``weight`` vector.
            eps: Constant added to the mean of squares before the
                reciprocal square root.
        """
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def __call__(self, x):
        """Return ``x`` normalized by ``rms_norm`` with this layer's parameters."""
        # Single code path: the previous inline float32 implementation was
        # superseded by the compiled helper, so only the delegation remains.
        return rms_norm(x, self.weight, self.eps)
class Attention(nn.Module):