[mlx-lm] Add precompiled normalizations (#451)

* add precompiled normalizations

* nits
This commit is contained in:
Awni Hannun
2024-02-22 12:40:55 -08:00
committed by GitHub
parent 97c09a863d
commit f24edfa9dc
13 changed files with 74 additions and 105 deletions

View File

@@ -1,17 +1,13 @@
import glob
import inspect
import json
import math
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple
from dataclasses import dataclass
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download
from mlx.utils import tree_unflatten
from transformers import AutoTokenizer
from .layers import LayerNorm
@dataclass
@@ -37,11 +33,6 @@ class ModelArgs:
)
class LayerNorm(nn.LayerNorm):
def __call__(self, x: mx.array) -> mx.array:
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
class RoPEAttention(nn.Module):
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
super().__init__()