[mlx-lm] Add precompiled normalizations (#451)

* add precompiled normalizations

* nits
This commit is contained in:
Awni Hannun
2024-02-22 12:40:55 -08:00
committed by GitHub
parent 97c09a863d
commit f24edfa9dc
13 changed files with 74 additions and 105 deletions

View File

@@ -6,6 +6,7 @@ import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .layers import LayerNorm
@dataclass
@@ -27,11 +28,6 @@ class ModelArgs(BaseModelArgs):
self.num_key_value_heads = self.num_attention_heads
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class PhiAttention(nn.Module):
def __init__(self, config: ModelArgs):
super().__init__()