mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
[mlx-lm] Add precompiled normalizations (#451)
* add precompiled normalizations * nits
This commit is contained in:
@@ -1,17 +1,13 @@
|
||||
import glob
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_unflatten
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from .layers import LayerNorm
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -37,11 +33,6 @@ class ModelArgs:
|
||||
)
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
return super().__call__(x.astype(mx.float32)).astype(x.dtype)
|
||||
|
||||
|
||||
class RoPEAttention(nn.Module):
|
||||
def __init__(self, dims: int, num_heads: int, rotary_dim: int):
|
||||
super().__init__()
|
||||
|
Reference in New Issue
Block a user